
Conversation

Contributor

@sugunav14 sugunav14 commented Oct 13, 2025

What does this PR do?

Type of change: New example

Overview: This PR adds an end-to-end example for calibrating a model in a multinode setup and exporting a checkpoint that is ready for deployment with TensorRT-LLM/vLLM/SGLang. To enable this, this PR:

  1. Allows loading a model and sharding it across multiple nodes with Accelerate FSDP2.
  2. Uses a context manager to make FSDP2-aware weight updates: the context manager gathers the full layer, registers new FSDPParam/QFSDPParam objects once the weights are updated, and finally redistributes the gathered layer (see the sketch after this list).
  3. Includes unit tests that verify the context manager's functionality for export-related functions.
  4. Provides a template FSDP2 Accelerate config file whose parameters can be overridden to match user requirements.
  5. Updates the README with usage instructions.
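
A minimal sketch of the weight-update flow from item 2, using only the public FSDP2 module methods; the actual fsdp2_aware_weight_update utility added in this PR additionally rebuilds the FSDPParam/QFSDPParam mapping and handles mixed-precision dtypes:

# Illustrative sketch only, not the implementation added by this PR.
from contextlib import contextmanager

from torch.distributed._composable.fsdp import FSDPModule


@contextmanager
def simplified_fsdp2_weight_update(module):
    """Unshard an FSDP2 module, let the caller mutate its weights, then reshard."""
    if not isinstance(module, FSDPModule):
        yield  # plain nn.Module: nothing to gather or redistribute
        return
    module.unshard()  # gather the full (unsharded) parameters
    try:
        yield  # caller updates or compresses weights here
    finally:
        # The real utility re-registers FSDPParam/QFSDPParam objects for the new
        # storage/dtype before redistributing the gathered layer.
        module.reshard()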

Usage

On each node, run the following command:

accelerate launch --config_file fsdp2.yaml --fsdp_transformer_layer_cls_to_wrap=<model_specific_decoder_layer_name> --num_machines=<num_nodes> --machine_rank=<current_node_machine_rank> --main_process_ip=<node_0_ip_addr> --main_process_port=<port> --num_processes=<total_no_of_gpus_across_nodes>  multinode_ptq.py --pyt_ckpt_path "meta-llama/Meta-Llama-3-8B" --qformat fp8 --batch_size=1 --calib_size=8 --kv_cache_qformat "none" --trust_remote_code
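
For example, on a hypothetical two-node setup with two GPUs per node (the layer class, IP address, and port below are placeholders), node 0 would run:

accelerate launch --config_file fsdp2.yaml --fsdp_transformer_layer_cls_to_wrap=LlamaDecoderLayer --num_machines=2 --machine_rank=0 --main_process_ip=10.0.0.1 --main_process_port=29500 --num_processes=4 multinode_ptq.py --pyt_ckpt_path "meta-llama/Meta-Llama-3-8B" --qformat fp8 --batch_size=1 --calib_size=8 --kv_cache_qformat "none" --trust_remote_code

Node 1 runs the same command with --machine_rank=1.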

Testing

  • Unit tests to verify functionality of export logic
  • e2e tests in multinode setup
  • Benchmark (Mixtral 8x7B FP8 export)
    CPU offloading -> one A6000 -> calibration (batch size = 28): 17 min + export ~127 s
    Multinode -> 4 A6000s (2 per node) -> calibration (batch size = 68): 6 min + export ~340 s
  • Deploy exported checkpoint using vLLM and verify sample prompts
  • Sanity check state dict between non-FSDP exported checkpoint and FSDP exported checkpoint
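
The state-dict sanity check in the last item can be scripted roughly as follows (checkpoint paths are hypothetical, and both checkpoints are assumed to be single-file safetensors exports):

import torch
from safetensors.torch import load_file

ref = load_file("non_fsdp_ckpt/model.safetensors")  # checkpoint exported without FSDP
new = load_file("fsdp_ckpt/model.safetensors")      # checkpoint exported with FSDP2

assert ref.keys() == new.keys(), "key sets differ"
for name in ref:
    assert ref[name].shape == new[name].shape, f"shape mismatch for {name}"
    torch.testing.assert_close(ref[name].float(), new[name].float())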

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

  • New Features

    • End-to-end multi-node PTQ workflow with FSDP2 support, CLI, calibration, quantization, and export.
  • Documentation

    • Added README entry and changelog describing multi-node usage, config template, deployment notes, and a performance caveat.
  • Refactor

    • Introduced reusable quant-config builder and FSDP2-aware utilities to streamline quantization/export flows.
  • Bug Fixes

    • Fixed device-consistent numeric handling in a quantization path.
  • Tests

    • Added GPU-distributed tests covering FSDP2 export, weight-update/compression, fusion quantization, and quantized export.

@sugunav14 sugunav14 requested review from a team as code owners October 13, 2025 23:47
@sugunav14 sugunav14 requested a review from cjluo-nv October 13, 2025 23:47
@sugunav14 sugunav14 marked this pull request as draft October 13, 2025 23:47

copy-pr-bot bot commented Oct 13, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Contributor

coderabbitai bot commented Oct 13, 2025

Walkthrough

Adds FSDP2-aware utilities and contexts for coordinated weight updates and dtype patching, integrates Accelerator into export checkpointing, refactors qtensor compression for per-submodule handling, adds multi-node FSDP2 PTQ examples and config, introduces FSDP2 export GPU tests, updates test utilities, and fixes an NVFP4 device-casting bug.

Changes

  • Docs & Config (examples/llm_ptq/README.md, examples/llm_ptq/fsdp2.yaml, CHANGELOG.rst): Adds multinode FSDP2 PTQ documentation, a multinode FSDP2 YAML config, and a changelog entry. (README addition duplicated in file.)
  • FSDP2 Quantization Utilities (modelopt/torch/quantization/utils.py): New helpers/contexts: patch_fsdp_mp_dtypes, get_prefixed_param_names, create_fsdp_param_mapping, no_requires_grad, enable_fake_quant, _get_module_name, _get_enclosing_fsdp_module, and fsdp2_aware_weight_update.
  • Qtensor Base Refactor (modelopt/torch/quantization/qtensor/base_qtensor.py): Removes several internal helpers and replaces bulk FSDP mapping with a per-submodule compression/update loop guarded by fsdp2_aware_weight_update and patch_fsdp_mp_dtypes. Public APIs unchanged.
  • Export Pipeline (modelopt/torch/export/unified_export_hf.py): Integrates FSDP2-aware contexts around fusion/quant/export steps, accepts **kwargs to obtain an accelerator, prefers accelerator.get_state_dict() when present, and tracks per-FSDP resharding to reduce redundant reshards.
  • NVFP4 Bugfix (modelopt/torch/quantization/qtensor/nvfp4_tensor.py): Casts weights_scaling_factor_2 to the device of per_block_amax before division to avoid a device mismatch in the scaling computation.
  • Multi-Node PTQ Script & Utils (examples/llm_ptq/multinode_ptq.py, examples/llm_ptq/example_utils.py, examples/llm_ptq/hf_ptq.py): Adds an end-to-end multinode PTQ CLI script with an FSDP2-compatible calibration/export flow and public helpers; adds a build_quant_cfg(...) helper and refactors hf_ptq.py to use it.
  • Tests & Test Utilities (tests/_test_utils/torch_export/export_utils.py, tests/gpu/torch/export/test_export.py, tests/gpu/torch/export/test_fsdp2_export.py): Renames SmallQKVModel to SmallLinearModelwithCustomWeight, updates ToyModel to accept bias, adds a new SmallQKVModel test class, updates tests to the new class, adds a debug print, and adds comprehensive multiprocess/NCCL FSDP2 export tests exercising weight-update/compression, fusion quantization, and quantized export flows.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant Accelerator
    participant FSDP as FSDPModule
    participant Utils as fsdp2_aware_weight_update
    participant Exporter as unified_export_hf

    User->>Accelerator: init (Accelerate + FSDP2)
    Accelerator->>FSDP: wrap model (FSDPModule)

    User->>Exporter: _export_hf_checkpoint(model, ..., accelerator=Accelerator)
    Exporter->>Utils: enter fsdp2_aware_weight_update(model, modules)
    Note right of Utils #D0F0C0: unshard relevant params, patch mp dtypes, map FSDP params
    Utils->>FSDP: perform per-submodule fusion / fake-quant / weight update
    Utils->>FSDP: optional reshard after forward or explicit reshard=False for exports
    Utils-->>Exporter: exit context (restore MP dtypes, reshard if needed)

    Exporter->>Accelerator: accelerator.get_state_dict(model)
    Accelerator-->>Exporter: unified state_dict
    Exporter-->>User: write HF checkpoint / artifacts

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 I hopped through shards and patched dtypes bright,
I unshard by moon and reshard by light.
Per-module hops, maps tidy and neat,
Multinode quantized — a carrot-treat! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 53.19%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The PR title "[OMNIML-2182]: Add example for multinode calibration using FSDP2" accurately captures the primary objective of this changeset. The main additions include a complete multinode PTQ workflow script at examples/llm_ptq/multinode_ptq.py, an FSDP2 configuration template at examples/llm_ptq/fsdp2.yaml, and comprehensive documentation updates to the README. While the PR also includes supporting infrastructure changes such as FSDP2-aware weight update utilities and test coverage, these are all in service of enabling the multinode calibration example. The title is specific, concise, uses clear terminology, and would allow a developer scanning the PR history to immediately understand the main contribution.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


codecov bot commented Oct 14, 2025

Codecov Report

❌ Patch coverage is 36.36364% with 63 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.31%. Comparing base (7ccaa53) to head (b82f95b).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/utils.py 32.60% 62 Missing ⚠️
...odelopt/torch/quantization/qtensor/nvfp4_tensor.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #432      +/-   ##
==========================================
- Coverage   73.38%   73.31%   -0.08%     
==========================================
  Files         180      180              
  Lines       17986    18028      +42     
==========================================
+ Hits        13199    13217      +18     
- Misses       4787     4811      +24     

☔ View full report in Codecov by Sentry.

@sugunav14 sugunav14 self-assigned this Oct 14, 2025
@sugunav14 sugunav14 marked this pull request as ready for review October 14, 2025 18:54
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
examples/llm_ptq/README.md (1)

244-245: Hyphenate the compound modifier

Please change “user specific requirements” to “user-specific requirements” for grammatical correctness.

tests/gpu/torch/export/test_export.py (1)

383-384: Remove leftover debug print

This print statement spams GPU test logs and should be dropped now that the test is stable.

-            print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
examples/llm_ptq/multinode-ptq.py (1)

245-256: Clarify the comment about model usage.

The comments on lines 253-255 are confusing and potentially contradictory. They state "we should forward pass using the unwrapped model" but then call model(**batch) which is the FSDP-wrapped model.

Consider revising to something like:

# For FSDP2, use the outer FSDP-wrapped model rather than unwrapped_model
# mtq.quantize unwraps the model before passing it to forward_loop as unwrapped_model,
# but we need the FSDP-wrapped version to properly handle DTensor
model(**batch)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 46a9e49 and 067f02d.

📒 Files selected for processing (9)
  • examples/llm_ptq/README.md (1 hunks)
  • examples/llm_ptq/fsdp2.yaml (1 hunks)
  • examples/llm_ptq/multinode-ptq.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (8 hunks)
  • modelopt/torch/quantization/qtensor/base_qtensor.py (3 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/_test_utils/torch_export/export_utils.py (2 hunks)
  • tests/gpu/torch/export/test_export.py (3 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
tests/_test_utils/torch_export/export_utils.py (2)
modelopt/torch/export/model_config.py (1)
  • bias (153-163)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • bias (303-307)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
  • QFSDPParam (139-159)
  • QTensorWrapper (87-136)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/conversion.py (1)
  • set_quantizer_by_cfg_context (300-322)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (591-669)
  • quantizer_attr_names (232-243)
modelopt/torch/export/quant_utils.py (2)
  • preprocess_linear_fusion (940-1010)
  • fuse_prequant_layernorm (926-937)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (205-343)
  • requantize_resmooth_fused_llm_layers (88-202)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (591-669)
  • patch_fsdp_mp_dtypes (482-515)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
  • SmallLinearModelwithCustomWeight (36-54)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
  • maxbound (193-199)
modelopt/torch/export/quant_utils.py (1)
  • get_scaling_factor (212-229)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
modelopt/torch/quantization/utils.py (2)
  • enable_fake_quant (574-587)
  • fsdp2_aware_weight_update (591-669)
examples/llm_ptq/multinode-ptq.py (8)
examples/llm_ptq/example_utils.py (1)
  • apply_kv_cache_quant (245-256)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (346-511)
modelopt/torch/quantization/config.py (1)
  • need_calibration (1164-1190)
modelopt/torch/quantization/utils.py (2)
  • patch_fsdp_mp_dtypes (482-515)
  • _init_mp_dtypes (485-506)
modelopt/torch/utils/dataset_utils.py (2)
  • get_dataset_dataloader (157-232)
  • get_supported_datasets (235-249)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
modelopt/torch/quantization/model_quant.py (1)
  • print_quant_summary (463-470)
🪛 LanguageTool
examples/llm_ptq/README.md

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
🔇 Additional comments (17)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)

22-22: LGTM!

The import changes correctly reflect the refactor that moves helper utilities to modelopt/torch/quantization/utils.py. The removal of MixedPrecisionPolicy and contextmanager is appropriate since they are now used in the utilities module.


235-252: LGTM!

The refactor correctly adopts the new per-submodule processing flow using fsdp2_aware_weight_update. This approach is cleaner than the previous FSDPParam mapping reconstruction logic and properly handles weight updates in FSDP2 contexts.

examples/llm_ptq/multinode-ptq.py (4)

56-112: LGTM!

The argument parsing is straightforward and correctly handles comma-separated lists for datasets and calibration sizes. The default handling for None datasets is properly addressed in create_calibration_dataloader.


115-145: LGTM!

The model loading and preparation logic is correct. The dummy optimizer with lr=0.0 is a necessary workaround for FSDP2's requirement to prepare an optimizer alongside the model.


148-176: LGTM!

The calibration dataloader creation is correct. Keeping data on CPU (device=None) and handling device transfer in the calibration loop is the right approach for FSDP2 workflows.


179-223: LGTM!

The quantization configuration logic is well-structured and correctly handles AWQ configuration, KV cache quantization, and model-specific adjustments. The use of copy.deepcopy prevents unintended mutations of the base configuration.

tests/gpu/torch/export/test_fsdp2_export.py (4)

35-69: LGTM!

The test correctly verifies that weight updates made within the fsdp2_aware_weight_update context are properly preserved after unsharding and that the forward pass produces expected results.


72-100: LGTM!

The test correctly verifies that weight compression (including dtype changes to FP8) is properly handled by the fsdp2_aware_weight_update context and preserved after unsharding.


119-214: LGTM!

The tests correctly verify that FSDP2 and non-FSDP2 quantization paths produce equivalent results for layer fusion and weight export. The comparison approach using _compare_parameters_and_buffers is appropriate and thorough.


217-282: LGTM!

The parametrized tests provide comprehensive coverage of various quantization configurations for FSDP2 workflows. The use of spawn_multiprocess_job correctly simulates distributed execution.

Note: Line 244 has a commented configuration (mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG). If this is intentionally disabled due to a known issue, consider adding a comment explaining why.

modelopt/torch/quantization/utils.py (7)

18-36: LGTM!

The new imports are appropriate for FSDP2 support. The use of __future__ annotations and TYPE_CHECKING guard for Generator follows best practices for type hinting.


364-367: LGTM!

The _get_module_name helper is a clean utility for locating a module's name within a root model. The None return for not-found cases is appropriate.


482-515: LGTM with note on global patching.

The function correctly patches FSDP2 to handle mixed dtypes during quantization. The implementation is copied from the latest PyTorch FSDP version for forward compatibility.

Important: This is a global monkey-patch of PyTorch internals. Callers must ensure the original function is restored after use to avoid side effects on other code paths. As noted in the review of examples/llm_ptq/multinode-ptq.py, consider implementing this as a context manager for safer usage.
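
A generic sketch of the context-manager pattern suggested above; the object and attribute names are placeholders, not the actual FSDP internals that patch_fsdp_mp_dtypes touches:

from contextlib import contextmanager


@contextmanager
def temporarily_patched(obj, attr_name, replacement):
    """Swap obj.<attr_name> for `replacement` and always restore the original."""
    original = getattr(obj, attr_name)
    setattr(obj, attr_name, replacement)
    try:
        yield
    finally:
        setattr(obj, attr_name, original)


# Usage sketch (names are hypothetical):
# with temporarily_patched(fsdp_param_module, "_init_mp_dtypes", patched_init_mp_dtypes):
#     run_quantization_and_export()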


518-533: LGTM!

The function correctly derives full parameter names by matching parameter identities and stripping the parameter name suffix. The use of id() for matching and rsplit(".", 1)[0] for prefix extraction is appropriate.


536-549: LGTM!

The function correctly builds a mapping from full parameter names to their corresponding FSDPParam objects using the get_prefixed_param_names helper.


552-587: LGTM!

Both context managers are well-implemented workarounds for FSDP2 limitations:

  • no_requires_grad temporarily patches Parameter.__new__ to avoid errors with integer tensors during FSDP2 parameter creation
  • enable_fake_quant prevents weight compression during unshard operations

Both correctly save and restore original state.


590-669: LGTM!

The fsdp2_aware_weight_update context manager is the core FSDP2 weight update utility and is well-implemented. It correctly:

  1. Locates the enclosing FSDP module for the target modules
  2. Unshards weights if needed (with fake_quant protection)
  3. Yields for weight updates
  4. Creates new FSDPParam/QFSDPParam with updated dtype policies
  5. Updates the FSDPParamGroup and reshards

The TODO at line 668 notes a potential performance optimization to conditionally reshard only when necessary. This is a reasonable optimization for future work.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (5)
tests/gpu/torch/export/test_fsdp2_export.py (3)

96-104: Consider verifying shape and values after compression.

The test only checks dtype but not the compressed shape or values. For a comprehensive compression test, verify that param.data.shape == (2, 2) and optionally check that the compressed values are as expected.

Apply this diff to add shape verification:

     torch.distributed.barrier()
     model.linears.unshard()
     # Check if weights are as expected after unshard
     for param in model.parameters():
         assert param.data.dtype == torch.float8_e4m3fn
+        assert param.data.shape == (2, 2), f"Expected shape (2, 2), got {param.data.shape}"

210-213: Remove unnecessary context manager for non-FSDP model.

fsdp2_aware_weight_update is not needed for non_fsdp_model since it's not an FSDPModule. The context manager will no-op but adds unnecessary overhead.

Apply this diff:

     for name, sub_module in non_fsdp_model.named_modules():
         if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+            _export_quantized_weight(sub_module, torch.float16)

240-262: Clarify the commented-out quantization config.

mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG is commented out on Lines 249 and 274. Is this config known to fail, or is it pending implementation? Consider adding a comment explaining why it's disabled or opening an issue to track enabling it.

Do you want me to open a new issue to track enabling this quantization config, or should we add a comment explaining why it's currently disabled?

modelopt/torch/quantization/utils.py (2)

363-367: Consider caching for large models.

Building dict(root_model.named_modules()) on every call could be expensive for large models. If this function is called frequently in a loop, consider caching the mapping or passing it as a parameter.


481-520: Consider documenting the torch version this patch targets.

The docstring mentions copying from "the latest version of torch FSDP" but doesn't specify which version. Consider adding the specific PyTorch version to help future maintainers understand when this patch can be removed.

Example:

-    """Patch FSDP2 to handle mixed dtypes properly during quantization."""
+    """Patch FSDP2 to handle mixed dtypes properly during quantization.
+    
+    This patch is based on PyTorch version X.Y.Z and can be removed once
+    the minimum supported PyTorch version includes this fix.
+    """
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 067f02d and 6e07dff.

📒 Files selected for processing (3)
  • examples/llm_ptq/multinode-ptq.py (1 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_ptq/multinode-ptq.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (205-343)
  • requantize_resmooth_fused_llm_layers (88-202)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (596-676)
  • patch_fsdp_mp_dtypes (482-520)
modelopt/torch/quantization/utils.py (3)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
  • QFSDPParam (139-159)
  • QTensorWrapper (87-136)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (6)
tests/gpu/torch/export/test_fsdp2_export.py (2)

33-37: LGTM!

The autouse fixture ensures consistent FSDP mixed precision dtype handling across all tests in this module.


40-75: LGTM!

The test correctly verifies that weight updates within the fsdp2_aware_weight_update context are properly synchronized across FSDP shards.

modelopt/torch/quantization/utils.py (4)

18-35: LGTM!

The new imports support the FSDP2-aware utilities being added, including proper type hints with TYPE_CHECKING to avoid runtime overhead.


578-592: LGTM!

The context manager correctly preserves and restores the _fake_quant state for all weight quantizers in the module hierarchy.


595-629: LGTM!

The context manager entry correctly handles FSDP2-specific setup, including conditional unsharding and FSDPParam mapping creation. The validation ensures all modules belong to the same FSDP group.


630-676: Verify the selective init_dtype_attrs call.

Line 663 calls init_dtype_attrs only for non-QFSDPParam instances. Is this intentional? If QFSDPParam.__init__ already calls this method (as suggested by line 488 in the relevant code snippets), this is correct. However, if it doesn't, or if the behavior differs, this could lead to inconsistencies.

Please confirm whether QFSDPParam.__init__ already handles dtype attribute initialization, or if this selective call is the correct behavior.

Additionally, the TODO on Line 675 mentions a performance optimization for conditional resharding. Consider addressing this if the unconditional reshard() becomes a bottleneck during export.

Comment on lines +557 to +582
@contextmanager
def no_requires_grad():
    """Context manager to temporarily set requires_grad to False.

    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
    for integer tensors.
    """
    original_new = torch.nn.Parameter.__new__

    def patched_new(cls, data=None, requires_grad=True):
        return original_new(cls, data, requires_grad=False)

    torch.nn.Parameter.__new__ = patched_new
    try:
        yield
    finally:
        torch.nn.Parameter.__new__ = original_new
Contributor


⚠️ Potential issue | 🟠 Major

Document thread-safety concerns for this global patch.

Patching torch.nn.Parameter.__new__ globally is invasive and could cause issues if other code creates Parameters concurrently. The patched version also ignores the requires_grad argument, which might surprise callers.

Consider:

  1. Adding a warning in the docstring about thread-safety and concurrent usage
  2. Documenting why this approach is necessary (i.e., FSDP2 integer tensor constraints)
  3. Investigating if there's a less invasive alternative

Example docstring enhancement:

 @contextmanager
 def no_requires_grad():
-    """Context manager to temporarily set requires_grad to False.
+    """Context manager to globally patch Parameter creation to set requires_grad=False.
 
-    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
-    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
-    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
-    for integer tensors.
+    **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
+    Do not use in multi-threaded contexts or when other code might be creating Parameters.
+    
+    This workaround allows calling init_sharded_parameter() on compressed integer weights.
+    FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

@contextmanager
def no_requires_grad():
    """Context manager to temporarily set requires_grad to False.

    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
    for integer tensors.
    """
    original_new = torch.nn.Parameter.__new__

    def patched_new(cls, data=None, requires_grad=True):
        return original_new(cls, data, requires_grad=False)

    torch.nn.Parameter.__new__ = patched_new
    try:
        yield
    finally:
        torch.nn.Parameter.__new__ = original_new

@contextmanager
def no_requires_grad():
    """Context manager to globally patch Parameter creation to set requires_grad=False.

    **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
    Do not use in multi-threaded contexts or when other code might be creating Parameters.

    This workaround allows calling init_sharded_parameter() on compressed integer weights.
    FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
    """
    original_new = torch.nn.Parameter.__new__

    def patched_new(cls, data=None, requires_grad=True):
        return original_new(cls, data, requires_grad=False)

    torch.nn.Parameter.__new__ = patched_new
    try:
        yield
    finally:
        torch.nn.Parameter.__new__ = original_new
🤖 Prompt for AI Agents
In modelopt/torch/quantization/utils.py around lines 557 to 575, the
no_requires_grad context manager globally patches torch.nn.Parameter.__new__
which is invasive and not thread-safe; update the docstring to (1) clearly warn
that this patch is global and may break concurrent code or other libraries
creating Parameters, (2) explain the specific reason it's required here (FSDP2
creates Parameters with default requires_grad and integer tensors cannot have
requires_grad=True), and (3) note that callers should avoid concurrent Parameter
creation while this context runs and that the patch ignores the requires_grad
argument on purpose. Also add a short TODO suggesting investigation of less
invasive alternatives (e.g., patching only local factories, using a wrapper
function or FSDP API) and ensure the docstring mentions that the original
__new__ is restored in the finally block.

_export_quantized_weight(sub_module, dtype, weight_name)

quantized_state_dict = model.state_dict()
if accelerator is not None:
Collaborator


can we infer this flag by looking at the dist ranks etc?

Contributor Author


I think we can. Is it to avoid passing accelerator? In which case we would still need accelerator to gather all the local state dicts and get the full state dict.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_hf.py (1)

29-29: Avoid hard import of accelerate; make Accelerator optional and lazy

Importing Accelerator at module scope forces a hard dependency on accelerate and breaks environments that don’t use FSDP2. Also, annotating with Accelerator requires accelerate at import time. Make the import lazy and relax the type hint to Any, and guard usage.

Apply this diff:

@@
-import torch.nn as nn
-from accelerate import Accelerator
+import torch.nn as nn
@@
-from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
+from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
@@
-def _export_hf_checkpoint(
-    model: nn.Module,
-    dtype: torch.dtype | None = None,
-    accelerator: Accelerator | None = None,
-) -> tuple[dict[str, Any], dict[str, Any]]:
+def _export_hf_checkpoint(
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    accelerator: Any | None = None,
+) -> tuple[dict[str, Any], dict[str, Any]]:
@@
-    if accelerator is not None:
-        # Gather state_dict from all ranks
-        quantized_state_dict = accelerator.get_state_dict(model)
+    if accelerator is not None:
+        # Ensure accelerate is available when requested
+        try:
+            from accelerate import Accelerator as _Accelerator  # type: ignore
+        except Exception as e:  # pragma: no cover
+            raise ImportError(
+                "accelerate must be installed to export from an FSDP2-wrapped model."
+            ) from e
+        # Optionally sync before gathering to ensure all ranks finished modification
+        if hasattr(accelerator, "wait_for_everyone"):
+            accelerator.wait_for_everyone()
+        # Gather state_dict from all ranks
+        quantized_state_dict = accelerator.get_state_dict(model)
     else:
         quantized_state_dict = model.state_dict()

Also applies to: 347-350, 496-501

🧹 Nitpick comments (4)
modelopt/torch/export/unified_export_hf.py (1)

118-120: Reduce repeated unshard/reshard by batching fsdp2_aware_weight_update scopes

Each with fsdp2_aware_weight_update may unshard/reshard the same root module repeatedly. Batch module updates per root FSDPModule to cut overhead (especially large models).

  • Collect modules per root FSDP group, then wrap once per group.
  • For MoE loops, accumulate per-expert modules and process in a single context where possible.

Also applies to: 166-168, 177-179, 199-201, 474-476, 493-495
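
A rough sketch of the batching idea, assuming a hypothetical find_enclosing_fsdp_module helper (the PR's _get_enclosing_fsdp_module utility plays a similar role) together with the fsdp2_aware_weight_update context added in this PR:

from collections import defaultdict

from modelopt.torch.quantization.utils import fsdp2_aware_weight_update


def update_in_batches(model, target_modules, update_fn, find_enclosing_fsdp_module):
    """Group target modules by their enclosing FSDP2 root so each root is
    unsharded/resharded once instead of once per target module."""
    groups = defaultdict(list)
    for module in target_modules:
        groups[find_enclosing_fsdp_module(model, module)].append(module)
    for modules in groups.values():
        with fsdp2_aware_weight_update(model, modules):  # one unshard/reshard per group
            for module in modules:
                update_fn(module)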

examples/llm_ptq/multinode-ptq.py (3)

243-256: Parameter is unused; clarify intent and avoid confusion

calibrate ignores the unwrapped model parameter but comments are contradictory. Rename the arg to underscore and align the comment.

-def calibrate(unwrapped_model):
-    """Calibration loop that uses the FSDP-wrapped model."""
+def calibrate(_):
+    """Calibration loop that intentionally drives the FSDP-wrapped model for DTensor correctness."""
@@
-            # Use outer model (FSDP-wrapped), not the parameter
-            # Important: We should forward pass using the unwrapped model
-            # mtq.quantize will unwrap the model & pass to the forward_loop
+            # Use the outer FSDP-wrapped model; mtq.quantize passes an unwrapped view,
+            # but we drive the wrapped one for DTensor correctness.
             model(**batch)

125-127: Fix returns docstring to match actual return value

Docstring says two items but function returns three (model, model_type, original_architectures).

-    Returns:
-        Tuple of (prepared_model, model_type)
+    Returns:
+        Tuple of (prepared_model, model_type, original_architectures)

275-283: Synchronize ranks before file I/O

Add a barrier so rank 0 writes only after all ranks finish gathering/modification.

-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
-        model, torch.bfloat16, accelerator=accelerator
-    )
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(
+        model, torch.bfloat16, accelerator=accelerator
+    )
+    if hasattr(accelerator, "wait_for_everyone"):
+        accelerator.wait_for_everyone()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e07dff and f81c370.

📒 Files selected for processing (2)
  • examples/llm_ptq/multinode-ptq.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/conversion.py (1)
  • set_quantizer_by_cfg_context (300-322)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (596-676)
  • quantizer_attr_names (231-242)
modelopt/torch/export/quant_utils.py (2)
  • preprocess_linear_fusion (940-1010)
  • fuse_prequant_layernorm (926-937)
examples/llm_ptq/multinode-ptq.py (8)
examples/llm_ptq/example_utils.py (1)
  • apply_kv_cache_quant (245-256)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (346-510)
modelopt/torch/quantization/config.py (1)
  • need_calibration (1164-1190)
modelopt/torch/quantization/utils.py (1)
  • patch_fsdp_mp_dtypes (482-520)
modelopt/torch/utils/dataset_utils.py (2)
  • get_dataset_dataloader (157-232)
  • get_supported_datasets (235-249)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
modelopt/torch/quantization/model_quant.py (1)
  • print_quant_summary (463-470)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Collaborator


can we rename to multinode_ptq.py instead of multinode-ptq.py

        DataLoader for calibration
    """
    if dataset_names is None:
        dataset_names = ["cnn_dailymail"]
Collaborator


Now that we have the nemotron_v2 post-training dataset, please rebase this PR and follow the same implementation as hf_ptq.py.

kv_cfg = getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"]
quant_cfg = apply_kv_cache_quant(quant_cfg, kv_cfg)

# Model-specific adjustments
Collaborator


let's just remove these for simplicity


@contextmanager
def fsdp2_aware_weight_update(root_model, modules_to_update):
"""Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule."""
Collaborator


Could you briefly document what this section of code does under the hood?

Collaborator

@cjluo-nv cjluo-nv left a comment


Thanks for the implementation

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/unified_export_hf.py (1)

474-511: Reshard the last FSDP module before exiting the loop.

fsdp_module_to_reshard gets updated for every FSDPModule, and the previous module is resharded immediately. The last one, however, is never resharded, so we exit the loop with that module still unsharded. That leaves the full weights resident on each rank for the remainder of the export path, blowing up memory and defeating the communication savings the optimization was supposed to preserve. Please reshard the final module after the loop executes.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f81c370 and 417f17a.

📒 Files selected for processing (4)
  • examples/llm_ptq/README.md (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (10 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (596-676)
  • quantizer_attr_names (231-242)
modelopt/torch/export/quant_utils.py (3)
  • preprocess_linear_fusion (940-1010)
  • fuse_prequant_layernorm (926-937)
  • get_quantization_format (421-522)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/quantization/utils.py (4)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
  • QFSDPParam (139-159)
  • QTensorWrapper (87-136)
tests/gpu/torch/export/test_fsdp2_export.py (5)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (210-348)
  • requantize_resmooth_fused_llm_layers (93-207)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (596-676)
  • patch_fsdp_mp_dtypes (482-520)
modelopt/torch/quantization/qtensor/base_qtensor.py (3)
  • to (114-122)
  • dim (110-112)
  • quantize (67-76)
🪛 LanguageTool
examples/llm_ptq/README.md

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs


The exported checkpoint can be deployed using TensorRT-LLM/ vLLM/ SGLang. For more details refer to the [deployment section](#deployment) of this document.

> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size based on available GPU memory.*
Contributor


Could you share how long the calibration takes with FSDP?

Contributor Author


Just updated the results in the PR description!

Comment on lines 33 to 47
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
    "nvfp4": mtq.NVFP4_DEFAULT_CFG,
    "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
    "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
}

KV_QUANT_CFG_CHOICES = {
    "none": "none",
    "fp8": "FP8_KV_CFG",
    "nvfp4": "NVFP4_KV_CFG",
    "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
}
Contributor


Is there a reason some qformats, like FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, cannot be supported here compared to hf_ptq.py?

Contributor Author


Some of my unit tests were failing with this config. I will look into this issue in a follow-up PR.

Comment on lines 198 to 207
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
    weight_quantizer = weight_quantizer[0]

if awq_block_size:
    weight_quantizer["block_sizes"][-1] = awq_block_size

# Coarser search for certain models to avoid overflow
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
    quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
Contributor


Does hf_ptq.py also have something similar? Maybe we can make this a utility to avoid duplication.

Comment on lines 481 to 485
# Every time we encounter a new FSDPModule, we need to reshard the previous one
if fsdp_module_to_reshard is not None:
    fsdp_module_to_reshard.reshard()

fsdp_module_to_reshard = sub_module
Contributor


Could you share why we need to reshard the previous layer?

Contributor Author


We need to reshard to make sure we don't hit OOM. Every time we update a weight or run a forward pass, the full layer is materialized on the GPU. From my understanding, if we don't reshard we could hit OOM.

This is a small trick I used to avoid too many reshard calls and therefore speed up export: essentially we unshard a decoder layer -> export all of its linear layers -> then reshard once we are fully done with the decoder layer (see the sketch below).
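
In code, the trick described above looks roughly like this (illustrative only; model and the layer path are placeholders for an FSDP2-wrapped HF model):

import torch
from modelopt.torch.export.layer_utils import is_quantlinear
from modelopt.torch.export.unified_export_hf import _export_quantized_weight

decoder_layer = model.model.layers[0]  # hypothetical FSDP2-wrapped decoder layer
decoder_layer.unshard()                # materialize the full decoder layer once
for _, sub_module in decoder_layer.named_modules():
    if is_quantlinear(sub_module):
        _export_quantized_weight(sub_module, torch.bfloat16)
decoder_layer.reshard()                # free the full weights once the layer is done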

Contributor


I see, thanks for the detailed explanation. Could you also add a few lines of comments in the code to document this?

Collaborator


+1

@sugunav14 sugunav14 force-pushed the svelury/multinode-ptq branch from 417f17a to 1378f69 Compare October 21, 2025 18:06
@sugunav14 sugunav14 changed the title Add example for multinode calibration using FSDP2 [OMNIML-2182]: Add example for multinode calibration using FSDP2 Oct 21, 2025
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/export/unified_export_hf.py (1)

530-585: Add accelerator parameter to export_hf_checkpoint to support distributed exports

The review comment correctly identifies that _export_hf_checkpoint accepts an accelerator parameter (documented in its docstring at line 366) via **kwargs, but the public export_hf_checkpoint API doesn't expose or pass it through. This forces distributed export scenarios to bypass the public API and call the private helper directly, as seen in examples/llm_ptq/multinode_ptq.py. The suggested diff properly threads the parameter through to enable proper rank-gathering in distributed setups.

modelopt/torch/quantization/qtensor/base_qtensor.py (1)

235-253: Batch eligible leaf modules before resharding to avoid repeated expensive unshard/reshard cycles and assertion failures

The current code calls fsdp2_aware_weight_update for every named submodule in a loop, triggering an unshard/reshard cycle on each iteration. This is inefficient and risks assertion failures on non-leaf modules not present in the FSDP param mapping.

The function signature expects a list of modules (modules_to_update) designed for batch processing with a single reshard cycle. Filter for eligible modules (those with enabled, non-fake quantizers) upfront and pass them as a batch:

-        for _, submodule in fsdp_module.named_modules():
-            with fsdp2_aware_weight_update(fsdp_module, submodule):
-                _compress_and_update_module_weight(submodule)
+        # Collect eligible leaf modules once to avoid repeated reshard/unshard
+        eligible = []
+        for _, submodule in fsdp_module.named_modules():
+            if (
+                hasattr(submodule, "weight")
+                and (submodule.weight is not None and not submodule.weight.is_meta)
+                and hasattr(submodule, "weight_quantizer")
+                and submodule.weight_quantizer.is_enabled
+                and not getattr(submodule.weight_quantizer, "_fake_quant", False)
+                and submodule.weight.element_size() > 1
+            ):
+                eligible.append(submodule)
+        if not eligible:
+            return
+        # Unsharded already above; batch update mapping once and reshard once at exit
+        with fsdp2_aware_weight_update(fsdp_module, eligible, reshard=True):
+            for submodule in eligible:
+                _compress_and_update_module_weight(submodule)

Verify multinode FSDP2 export runs without "Module … not found in fsdp_param_mapping" assertions.

♻️ Duplicate comments (6)
examples/llm_ptq/fsdp2.yaml (1)

1-3: Consumer script documented; config looks sound

The top comment addresses the prior ask; the FSDP2 settings are sane for FULL_STATE_DICT export. Consider adding a brief note that fsdp_transformer_layer_cls_to_wrap is typically overridden per model via CLI.

tests/gpu/torch/export/test_export.py (1)

383-383: Remove debug print from test (already flagged earlier)

The DEBUG LOG line adds noise to CI.

Apply:

-            print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
tests/gpu/torch/export/test_fsdp2_export.py (3)

37-37: Remove redundant import.

The fully_shard import is already present at line 24. Per past review comments, imports should be at the top of the file.

Apply this diff:

-    from torch.distributed._composable.fsdp import fully_shard
-

75-75: Remove redundant import.

The fully_shard import is already present at line 24.

Apply this diff:

-    from torch.distributed._composable.fsdp import fully_shard
-

163-165: Remove redundant imports.

Both copy (line 163) and fully_shard (line 165) are already imported at the top of the file (lines 17 and 24 respectively).

Apply this diff:

-    import copy
-
-    from torch.distributed._composable.fsdp import fully_shard
-
modelopt/torch/quantization/utils.py (1)

564-582: Enhance docstring with thread-safety warnings and usage constraints.

As noted in a past review comment, this global monkey-patch is invasive and not thread-safe. The docstring should clearly warn about concurrent usage risks and explain why this workaround is necessary (FSDP2's default Parameter creation with requires_grad=True errors on integer dtypes).

Consider enhancing the docstring:

 @contextmanager
 def no_requires_grad():
-    """Context manager to temporarily set requires_grad to False.
+    """Context manager to globally patch Parameter creation to set requires_grad=False.
 
-    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
-    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
-    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
-    for integer tensors.
+    **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
+    Do not use in multi-threaded contexts or when other code might be creating Parameters.
+    
+    This workaround allows calling init_sharded_parameter() on compressed integer weights.
+    FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
+    The patch ignores the requires_grad argument and always sets it to False.
+    
+    The original __new__ is restored in the finally block.
     """
🧹 Nitpick comments (12)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)

214-217: Consider adding defensive device cast for completeness.

While the current fix at line 85 handles most cases, line 216 could benefit from a similar defensive cast to ensure weights_scaling_factor_2 is on the same device as input. This would guard against edge cases where both scaling factors are provided as parameters on different devices.

         # Scale weights
-        scaled_weight = input / (
-            (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1)
-        )
+        scaled_weight = input / (
+            (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2.to(input.device)).unsqueeze(-1)
+        )
examples/llm_ptq/multinode_ptq.py (5)

112-116: Normalize calib_size length for all dataset cases

Currently the length adjustment runs only when args.dataset is None. Users passing multiple datasets will hit an assertion in get_dataset_dataloader if lengths differ. Normalize unconditionally:

-    # Set default dataset if not provided
-    if args.dataset is None:
-        args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
-        warnings.warn(
-            "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
-        )
-        # Adjust calib_size to match dataset length by extending or truncating as needed
-        args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
-            : len(args.dataset)
-        ]
+    # Defaults
+    if args.dataset is None:
+        args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
+        warnings.warn("No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2.")
+    # Adjust calib_size to match dataset length either way
+    args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[: len(args.dataset)]

Also applies to: 289-299


182-214: Calibration loop: clarify param usage and add inference_mode

  • The unwrapped_model parameter is unused, and the comment is contradictory. Use a throwaway name and add torch.inference_mode() to avoid accidental grads.
-def create_fsdp2_calibration_loop(
+def create_fsdp2_calibration_loop(
     model: nn.Module,
     dataloader: torch.utils.data.DataLoader,
     accelerator: Accelerator,
 ):
@@
-    def calibrate(unwrapped_model):
-        """Calibration loop that uses the FSDP-wrapped model."""
-        for batch in tqdm(dataloader, desc="Calibrating"):
+    def calibrate(_):
+        """Calibration loop that runs the FSDP-wrapped model."""
+        with torch.inference_mode():
+            for batch in tqdm(dataloader, desc="Calibrating"):
             ...
-            # Use outer model (FSDP-wrapped), not the parameter
-            # Important: We should forward pass using the unwrapped model
-            # mtq.quantize will unwrap the model & pass to the forward_loop
+            # Use outer FSDP-wrapped model; mtq may pass an unwrapped handle here.
             model(**batch)

233-246: Export dtype: prefer model’s native dtype unless overridden

For robustness across models, let _export_hf_checkpoint default to model.config.torch_dtype:

-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
-        model, torch.bfloat16, accelerator=accelerator
-    )
+    post_state_dict, hf_quant_config = _export_hf_checkpoint(
+        model, accelerator=accelerator
+    )

237-259: Optional: add a barrier after main‑process write

To avoid any chance of other ranks reading a partially written file tree during subsequent steps, insert:

     if accelerator.is_main_process:
         ...
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)
+    accelerator.wait_for_everyone()

349-351: NIT: remove stale log

“Unpatching FSDP2 MP dtypes” is handled by the surrounding context manager; this print is misleading. Safe to drop.

examples/llm_ptq/README.md (1)

244-246: Hyphenation

Use “user‑specific” (hyphenated).

Apply:

-... customized for user specific requirements.
+... customized for user-specific requirements.
examples/llm_ptq/hf_ptq.py (1)

451-451: LGTM: single source of truth for quant_cfg

Using build_quant_cfg(...) here mirrors the multi‑node flow; reduces drift.

Consider adopting the same helper in multinode_ptq.py (already used) and ensuring both scripts share the exact qformat set.

examples/llm_ptq/example_utils.py (1)

64-73: KV‑cache application assumes quant_cfg value and attribute presence

getattr(mtq, kv_quant_cfg_choices[...]) must exist and be a dict with "quant_cfg". Add a defensive check to raise a clear error if missing, or default to no‑op.

tests/_test_utils/torch_export/export_utils.py (1)

57-84: SmallQKVModel ‘device’ attr isn’t used to place parameters/tensors

self.device is stored but the modules are never moved to it. Either remove the attribute or move the parameters/buffers to it in `__init__`.
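If the attribute is kept, one option is to act on it in `__init__` (a sketch; the real SmallQKVModel constructor arguments are assumptions here):

import torch.nn as nn

class SmallQKVModel(nn.Module):
    def __init__(self, hidden_size: int = 16, device: str = "cuda"):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.device = device
        self.to(device)  # actually place parameters/buffers on the stored device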

modelopt/torch/export/unified_export_hf.py (1)

171-174: fsdp2_aware_weight_update usage looks correct; consider grouping modules to reduce reshard churn

The contexts are applied per group (MoE experts, fused linears, layernorm, experts' gate/up/down), which is good. For large layers you can optionally pass lists (you already do in some paths) and keep reshard=False until the group is complete to minimize communication.

Also applies to: 182-184, 204-206, 490-511

tests/gpu/torch/export/test_fsdp2_export.py (1)

205-208: Remove unnecessary context manager usage on non-FSDP model.

The fsdp2_aware_weight_update context manager is a no-op when used on non-FSDP models (it only activates if isinstance(root_model, FSDPModule) per line 622 in modelopt/torch/quantization/utils.py). This usage is confusing and should be removed.

Apply this diff:

     for name, sub_module in non_fsdp_model.named_modules():
         if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+            _export_quantized_weight(sub_module, torch.float16)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 417f17a and 1378f69.

📒 Files selected for processing (13)
  • CHANGELOG.rst (1 hunks)
  • examples/llm_ptq/README.md (1 hunks)
  • examples/llm_ptq/example_utils.py (2 hunks)
  • examples/llm_ptq/fsdp2.yaml (1 hunks)
  • examples/llm_ptq/hf_ptq.py (2 hunks)
  • examples/llm_ptq/multinode_ptq.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (10 hunks)
  • modelopt/torch/quantization/qtensor/base_qtensor.py (3 hunks)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/_test_utils/torch_export/export_utils.py (2 hunks)
  • tests/gpu/torch/export/test_export.py (3 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
tests/_test_utils/torch_export/export_utils.py (4)
modelopt/torch/export/model_config.py (1)
  • bias (153-163)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • bias (303-307)
  • forward (847-946)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
  • forward (141-163)
modelopt/torch/quantization/backends/nvfp4_gemm.py (1)
  • forward (135-150)
examples/llm_ptq/hf_ptq.py (1)
examples/llm_ptq/example_utils.py (1)
  • build_quant_cfg (42-85)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
  • SmallLinearModelwithCustomWeight (36-54)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
  • maxbound (193-199)
modelopt/torch/export/quant_utils.py (1)
  • get_scaling_factor (212-229)
examples/llm_ptq/example_utils.py (1)
modelopt/torch/export/model_config.py (1)
  • awq_block_size (289-294)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
  • QFSDPParam (139-159)
  • QTensorWrapper (87-136)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (2)
modelopt/torch/export/model_config.py (1)
  • weights_scaling_factor_2 (237-266)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (114-122)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
modelopt/torch/quantization/utils.py (2)
  • enable_fake_quant (586-599)
  • fsdp2_aware_weight_update (603-699)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-699)
  • quantizer_attr_names (231-242)
modelopt/torch/export/quant_utils.py (3)
  • preprocess_linear_fusion (949-1019)
  • fuse_prequant_layernorm (935-946)
  • get_quantization_format (430-531)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (210-348)
  • requantize_resmooth_fused_llm_layers (93-207)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-699)
  • patch_fsdp_mp_dtypes (485-527)
examples/llm_ptq/multinode_ptq.py (6)
examples/llm_ptq/example_utils.py (2)
  • build_quant_cfg (42-85)
  • get_tokenizer (95-112)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (351-527)
modelopt/torch/quantization/config.py (1)
  • need_calibration (1164-1190)
modelopt/torch/quantization/utils.py (1)
  • patch_fsdp_mp_dtypes (485-527)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (171-246)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
🪛 LanguageTool
examples/llm_ptq/README.md

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (13)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1)

84-86: LGTM! Device consistency fix for multi-node scenarios.

The explicit device cast ensures weights_scaling_factor_2 matches the device of per_block_amax, preventing device mismatch errors in FSDP2 distributed training. This is critical for the multi-node quantization workflow.
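For reference, the fix amounts to an explicit cast onto the amax tensor's device, along these lines (a sketch using the names from this comment rather than the exact variables in the file):

import torch

per_block_amax = torch.rand(8, device="cuda")
weights_scaling_factor_2 = torch.tensor(448.0)  # created on CPU by default
weights_scaling_factor_2 = weights_scaling_factor_2.to(per_block_amax.device)
assert weights_scaling_factor_2.device == per_block_amax.device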

CHANGELOG.rst (1)

14-14: Link/anchor sanity check

Looks good. Please double-check the README anchor remains stable before release branching.

examples/llm_ptq/hf_ptq.py (1)

25-33: Good refactor: centralizing quant_cfg construction

Importing and using build_quant_cfg reduces duplication and keeps behavior consistent.

Please confirm example_utils.build_quant_cfg remains backward compatible with all qformats previously supported by this script.

modelopt/torch/export/unified_export_hf.py (1)

30-36: Good: accelerate import is now optional

Lazy import avoids hard dependency for non‑FSDP2 users. LGTM.

Confirm that environments without accelerate still run export paths and tests unaffected.

tests/gpu/torch/export/test_export.py (1)

309-311: LGTM on model switch in tests

Using SmallLinearModelwithCustomWeight matches the renamed test utility and keeps semantics intact.

Also applies to: 378-380

tests/gpu/torch/export/test_fsdp2_export.py (3)

93-101: Verify shape change in addition to dtype.

The weight is changed from shape (6, 6) to (2, 2) at lines 93-95, but the assertion at line 101 only verifies the dtype change. Consider also asserting the shape to ensure the compression properly updated the weight dimensions.

     for param in model.parameters():
         assert param.data.dtype == torch.float8_e4m3fn
+        assert param.data.shape == (2, 2)

104-118: LGTM!

The helper function correctly compares parameters and buffers between two models using bfloat16 conversion, which is appropriate for quantized model comparisons.
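For readers skimming the review, such a helper can look roughly like this (a sketch; the actual test helper and its tolerances may differ):

import torch

def assert_same_state(model_a: torch.nn.Module, model_b: torch.nn.Module):
    # Compare parameters by name; cast to bfloat16 so quantized dtypes line up.
    for (name_a, p_a), (name_b, p_b) in zip(
        model_a.named_parameters(), model_b.named_parameters()
    ):
        assert name_a == name_b
        assert torch.allclose(p_a.detach().to(torch.bfloat16), p_b.detach().to(torch.bfloat16))
    # Buffers (e.g. amax values) are compared the same way.
    for (name_a, b_a), (name_b, b_b) in zip(model_a.named_buffers(), model_b.named_buffers()):
        assert name_a == name_b
        assert torch.allclose(b_a.to(torch.bfloat16), b_b.to(torch.bfloat16))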


235-247: Document or remove commented test case.

The test case mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG is commented out at line 244 (and again at line 267). Please either remove this line if the test case is not needed, or add a TODO/NOTE comment explaining why it's disabled.

Example:

         mtq.W4A8_AWQ_BETA_CFG,
-        # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
+        # TODO: Enable once FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG is supported in multi-node setup
+        # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
         mtq.W4A8_MXFP4_FP8_CFG,
modelopt/torch/quantization/utils.py (5)

18-34: LGTM!

The addition of future annotations, TYPE_CHECKING, and proper type hints follows Python best practices for type annotations.


366-369: LGTM!

The helper function efficiently retrieves the module name using a generator expression with next().
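The pattern being described is roughly (a sketch consistent with the summary above, not the verbatim source):

import torch.nn as nn

def _get_module_name(root_model: nn.Module, target: nn.Module) -> str:
    # next() over named_modules() returns the qualified name of the first match;
    # raises StopIteration if the module is not part of root_model.
    return next(name for name, module in root_model.named_modules() if module is target)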


530-561: LGTM!

Both get_prefixed_param_names and create_fsdp_param_mapping are well-implemented with clear documentation and efficient use of generator expressions.


585-599: LGTM!

The enable_fake_quant context manager correctly stores and restores the _fake_quant attribute. The use of a list with pop() is safe because module.modules() returns modules in a consistent order.


602-699: LGTM!

The fsdp2_aware_weight_update context manager provides comprehensive orchestration for FSDP2-aware weight updates, including:

  • Proper unshard/reshard lifecycle management
  • FSDPParam mapping updates with correct dtype policies
  • Intelligent selection between QFSDPParam and FSDPParam based on weight type
  • Cleanup of old hook handles to prevent memory leaks

The implementation correctly handles both FSDP and non-FSDP models by checking isinstance(root_model, FSDPModule) before applying FSDP-specific logic.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/llm_ptq/example_utils.py (1)

42-85: Fix auto-quant path, avoid shared-state mutation, and harden AWQ/KV guards.

  • Auto path: when args.auto_quantize_bits=True, the function returns an empty {} and skips KV-cache quant + model-specific toggles. Implement explicit else branch that at minimum applies KV-cache quant and per-model disables.
  • Shared mutation: for non‑AWQ formats you assign quant_cfg = quant_cfg_choices[...], then mutate it (KV, phi4mm toggles). Always deep‑copy, not only for AWQ.
  • Robustness: guard args.awq_block_size (may be absent) and verify weight_quantizer["block_sizes"] has key -1 before assignment; otherwise KeyError. Also validate kv_cache_qformat mapping and mtq attribute existence before getattr.
  • Note: keeping Q/K/V AWQ block sizes aligned avoids downstream assert in modelopt/torch/export/model_config.py Lines 288‑293. [Relevant snippet referenced]

As echoed in a previous review on this function; this comment consolidates the unresolved parts.

Apply:

@@
-def build_quant_cfg(args, model_type, quant_cfg_choices, kv_quant_cfg_choices):
-    quant_cfg = {}
-    if not hasattr(args, "auto_quantize_bits") or not args.auto_quantize_bits:
+def build_quant_cfg(args, model_type, quant_cfg_choices, kv_quant_cfg_choices):
+    quant_cfg = {}
+    auto = getattr(args, "auto_quantize_bits", False)
+    if not auto:
         assert args.qformat in quant_cfg_choices, (
             f"Unsupported quantization format: {args.qformat} with {args.kv_cache_qformat} KV cache"
         )
-
-        quant_cfg = quant_cfg_choices[args.qformat]
+        # Always work on a private copy to avoid mutating shared choices
+        quant_cfg = copy.deepcopy(quant_cfg_choices[args.qformat])
@@
-        if "awq" in args.qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[args.qformat])
+        if "awq" in args.qformat:
             weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
             if isinstance(weight_quantizer, list):
-                weight_quantizer = weight_quantizer[0]
+                weight_quantizer = weight_quantizer[0]
             # If awq_block_size argument is provided, update weight_quantizer
-            if args.awq_block_size:
-                weight_quantizer["block_sizes"][-1] = args.awq_block_size
+            awq_bs = getattr(args, "awq_block_size", None)
+            if awq_bs:
+                bs = weight_quantizer.get("block_sizes")
+                if isinstance(bs, dict) and (-1 in bs):
+                    bs[-1] = awq_bs
+                else:
+                    # Leave unchanged if structure is absent/mismatched
+                    pass
@@
-        enable_quant_kv_cache = args.kv_cache_qformat != "none"
-        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+        enable_quant_kv_cache = getattr(args, "kv_cache_qformat", "none") != "none"
+        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
@@
-        if enable_quant_kv_cache:
-            quant_cfg = apply_kv_cache_quant(
-                quant_cfg,
-                getattr(mtq, kv_quant_cfg_choices[args.kv_cache_qformat])["quant_cfg"],
-            )
+        if enable_quant_kv_cache:
+            kv_key = args.kv_cache_qformat
+            assert kv_key in kv_quant_cfg_choices, f"Unsupported KV cache quant format: {kv_key}"
+            kv_attr = kv_quant_cfg_choices[kv_key]
+            kv_cfg_holder = getattr(mtq, kv_attr, None)
+            if kv_cfg_holder is None or "quant_cfg" not in kv_cfg_holder:
+                raise ValueError(f"KV cache quant cfg '{kv_attr}' not found in mtq or missing 'quant_cfg'")
+            quant_cfg = apply_kv_cache_quant(quant_cfg, kv_cfg_holder["quant_cfg"])
@@
-        if model_type == "phi4mm":
+        if model_type == "phi4mm":
             # Only quantize the language model
             quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
-
-    return quant_cfg
+    else:
+        # Auto-quant path: still allow KV-cache quantization and per-model toggles
+        enable_quant_kv_cache = getattr(args, "kv_cache_qformat", "none") != "none"
+        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+        if enable_quant_kv_cache:
+            kv_key = args.kv_cache_qformat
+            assert kv_key in kv_quant_cfg_choices, f"Unsupported KV cache quant format: {kv_key}"
+            kv_attr = kv_quant_cfg_choices[kv_key]
+            kv_cfg_holder = getattr(mtq, kv_attr, None)
+            if kv_cfg_holder is None or "quant_cfg" not in kv_cfg_holder:
+                raise ValueError(f"KV cache quant cfg '{kv_attr}' not found in mtq or missing 'quant_cfg'")
+            quant_cfg = apply_kv_cache_quant(quant_cfg, kv_cfg_holder["quant_cfg"])
+        if model_type == "phi4mm":
+            quant_cfg.setdefault("quant_cfg", {"default": {"enable": False}})
+            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
+            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
+    return quant_cfg

Run to validate assumptions and spot potential breakages:

#!/bin/bash
set -euo pipefail

echo "Call sites of build_quant_cfg:"
rg -n -C2 '\bbuild_quant_cfg\s*\(' examples/llm_ptq || true

echo
echo "Where auto_quantize_bits is referenced (ensure CLI plumbs it):"
rg -n -C2 '\bauto_quantize_bits\b' || true

echo
echo "KV cfg mapping definitions near callers:"
rg -n -C3 'KV_QUANT_CFG_CHOICES\s*=' examples/llm_ptq || true

echo
echo "AWQ configs mentioning block_sizes (to sanity-check structure):"
rg -n -C1 'block_sizes' examples/llm_ptq || true
🧹 Nitpick comments (1)
examples/llm_ptq/example_utils.py (1)

64-66: Prefer logging over print for user-facing toggles.

Switch to a module logger to control verbosity across multi-node runs.

+import logging
+logger = logging.getLogger(__name__)
@@
-        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
+        logger.info("%s KV cache quantization", "Enable" if enable_quant_kv_cache else "Disable")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1378f69 and f657534.

📒 Files selected for processing (1)
  • examples/llm_ptq/example_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/llm_ptq/example_utils.py (1)
modelopt/torch/export/model_config.py (1)
  • awq_block_size (289-294)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
🔇 Additional comments (2)
examples/llm_ptq/example_utils.py (2)

16-16: LGTM on import usage.

copy is used for AWQ deep-copy (and should be used more broadly per fix below).


36-36: LGTM on mtq import.

Explicit aliasing (mtq) reads clearly and matches usage.

SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]


def build_quant_cfg(args, model_type, quant_cfg_choices, kv_quant_cfg_choices):
Collaborator

Can we avoid passing the whole args object here? Please list the arguments explicitly.

mtq.NVFP4_DEFAULT_CFG,
mtq.NVFP4_AWQ_LITE_CFG,
mtq.W4A8_AWQ_BETA_CFG,
# mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
Collaborator

Please add a TODO; same below.

grouped_experts = get_experts_list(module, model_type)
for modules in grouped_experts:
preprocess_linear_fusion(modules, resmooth_only=True)
with fsdp2_aware_weight_update(model, modules):
Collaborator

Say we add a new processing step and the user forgets to add this fsdp2_aware_weight_update line. What error will the user see?

Contributor Author

@sugunav14 sugunav14 Oct 21, 2025

There wouldn't be an error. But the weight changes will be lost once the module is resharded.

FSDP2 keeps module weights as FSDPParams, which track both the sharded DTensor and the gathered full tensor. When we update module.weight, the update lands on the full tensor inside the FSDPParam, so it is lost when the module is resharded unless the FSDPParam itself is updated.
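To make that concrete for future readers: any in-place weight edit on an FSDP2-sharded module should be wrapped so the updated full tensor is re-registered before resharding. A minimal sketch of the pattern (the context manager and its (root_model, module) arguments come from modelopt/torch/quantization/utils.py in this PR; scale_weight is a hypothetical processing step):

import torch
from modelopt.torch.quantization.utils import fsdp2_aware_weight_update

def scale_weight(model: torch.nn.Module, linear: torch.nn.Module):
    # The context gathers the layer, lets us edit the full weight, registers the
    # new FSDPParam, and reshards on exit. Without it the edit raises no error
    # but silently disappears at the next reshard.
    with fsdp2_aware_weight_update(model, linear):
        with torch.no_grad():
            linear.weight.mul_(0.5)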

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/unified_export_hf.py (1)

526-556: Expose accelerator in public API and pass through

Allow callers to use the stable public API instead of importing the private helper.

-def export_hf_checkpoint(
-    model: nn.Module,
-    dtype: torch.dtype | None = None,
-    export_dir: Path | str = tempfile.gettempdir(),
-    save_modelopt_state: bool = False,
-):
+def export_hf_checkpoint(
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    export_dir: Path | str = tempfile.gettempdir(),
+    save_modelopt_state: bool = False,
+    *,
+    accelerator: Any | None = None,
+):
@@
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype, accelerator=accelerator)

I can propagate this change to the multinode script to stop using the private function; say the word.

♻️ Duplicate comments (3)
tests/gpu/torch/export/test_fsdp2_export.py (1)

217-233: Auto‑skip on insufficient GPUs

Previous feedback suggested a fixture for skipping when <2 GPUs. If available, please adopt it for these multiprocess tests to avoid spurious CI failures.

If you want, I can wire a need_2_gpus fixture and parametrize over it.

Also applies to: 249-256, 272-278
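For reference, the fixture can be as small as the following (the name need_2_gpus is taken from this comment; the repo may already ship an equivalent helper):

import pytest
import torch

@pytest.fixture
def need_2_gpus():
    # Skip the multi-process FSDP2 tests when fewer than two GPUs are visible.
    if torch.cuda.device_count() < 2:
        pytest.skip("Requires at least 2 GPUs")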

examples/llm_ptq/example_utils.py (1)

57-61: Avoid mutating shared quant configs; deep‑copy always

Assigning quant_cfg = quant_cfg_choices[qformat] and then mutating it (KV cache, AWQ tweaks) will corrupt the shared config objects. Always deep‑copy regardless of format.

-        quant_cfg = quant_cfg_choices[qformat]
+        # Always work on a private copy to avoid mutating shared constants
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])

Also remove the redundant deep‑copy in the AWQ branch below (kept by the above change).

modelopt/torch/export/unified_export_hf.py (1)

509-514: Do not silently export incomplete state_dict when FSDP is detected

If any FSDP modules were seen, exporting without an Accelerator will miss sharded params. Raise a clear error instead of silently falling back.

Also reshard the last FSDP module before gathering.

-    if accelerator is not None:
-        # Gather state_dict from all ranks
-        quantized_state_dict = accelerator.get_state_dict(model)
-    else:
-        quantized_state_dict = model.state_dict()
+    # Reshard the last FSDP module, if any
+    if fsdp_module_to_reshard is not None:
+        fsdp_module_to_reshard.reshard()
+
+    if accelerator is not None:
+        # Gather state_dict from all ranks
+        quantized_state_dict = accelerator.get_state_dict(model)
+    else:
+        if fsdp_module_to_reshard is not None:
+            raise ImportError(
+                "FSDP2-wrapped model detected: pass an `accelerator` to export for correct gathering."
+            )
+        quantized_state_dict = model.state_dict()
🧹 Nitpick comments (8)
examples/llm_ptq/README.md (2)

244-247: Hyphenation nit: “user‑specific”

Use a hyphen: “can be customized for user‑specific requirements.”


248-264: Keep README options in sync with the script

Consider either:

  • Listing all supported qformats from the script (int8, int4_awq, fp8, nvfp4, nvfp4_awq, w4a8_awq, fp8_pb_wo, w4a8_mxfp4_fp8, nvfp4_mlp_only), or
  • Add “see script for full list of supported qformats.”

Also consider showing the optional flag present in the script:

  • Append “--awq_block_size ” when demonstrating AWQ.

Would you like me to open a small PR update to expand the options and add --awq_block_size here?

tests/gpu/torch/export/test_fsdp2_export.py (2)

163-165: Remove redundant local import

copy is already imported at module scope; drop the inner import.

-    import copy
-
     from torch.distributed._composable.fsdp import fully_shard

200-209: Speed up test by avoiding per‑layer reshard during export

Use reshard=False in the context to match the production path optimization and reduce test runtime.

-                with fsdp2_aware_weight_update(model, sub_module):
+                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     _export_quantized_weight(sub_module, torch.float16)
examples/llm_ptq/example_utils.py (3)

59-67: Harden AWQ block_size edit

Guard for expected structure to avoid KeyError on configs without [-1] in block_sizes.

-        if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
+        if "awq" in qformat:
             weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
             if isinstance(weight_quantizer, list):
                 weight_quantizer = weight_quantizer[0]
             # If awq_block_size argument is provided, update weight_quantizer
             if awq_block_size:
-                weight_quantizer["block_sizes"][-1] = awq_block_size
+                bs = weight_quantizer.get("block_sizes", {})
+                if isinstance(bs, dict) and (-1 in bs or not bs):
+                    bs[-1] = awq_block_size
+                    weight_quantizer["block_sizes"] = bs

82-85: Condition will never match here; scope correctly

This branch checks "int8_sq" in qformat, but qformat is a single token and QUANT_CFG_CHOICES in the multinode flow doesn’t include int8_sq. Either:

  • Move this tweak to the place where int8_sq is actually supported, or
  • Guard by choices: if model_type == "gemma" and qformat == "int8_sq" and "int8_sq" in quant_cfg_choices: ...
-        if model_type == "gemma" and "int8_sq" in qformat:
+        if model_type == "gemma" and qformat == "int8_sq" and "int8_sq" in quant_cfg_choices:
             quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

51-93: Optional: handle auto‑quant path explicitly

When auto_quantize is True, this returns {}. If build_quant_cfg is reused by auto‑quant users, consider returning a minimal config and applying KV‑cache quant via apply_kv_cache_quant similarly to the non‑auto path.

I can draft the small else branch if you plan to reuse this for auto‑quant flows.

examples/llm_ptq/multinode_ptq.py (1)

319-326: Pass explicit boolean for auto_quantize flag

Use False instead of None for clarity.

-        None,
+        False,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f657534 and a4c7a5a.

📒 Files selected for processing (6)
  • examples/llm_ptq/README.md (1 hunks)
  • examples/llm_ptq/example_utils.py (2 hunks)
  • examples/llm_ptq/hf_ptq.py (2 hunks)
  • examples/llm_ptq/multinode_ptq.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (10 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_ptq/hf_ptq.py
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-699)
  • quantizer_attr_names (231-242)
modelopt/torch/export/quant_utils.py (3)
  • preprocess_linear_fusion (949-1019)
  • fuse_prequant_layernorm (935-946)
  • get_quantization_format (430-531)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (205-343)
  • requantize_resmooth_fused_llm_layers (88-202)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-699)
  • patch_fsdp_mp_dtypes (485-527)
examples/llm_ptq/example_utils.py (3)
modelopt/torch/export/model_config.py (1)
  • awq_block_size (289-294)
examples/llm_ptq/hf_ptq.py (1)
  • auto_quantize (94-152)
modelopt/torch/quantization/model_quant.py (1)
  • auto_quantize (234-450)
examples/llm_ptq/multinode_ptq.py (8)
examples/llm_ptq/example_utils.py (2)
  • build_quant_cfg (42-93)
  • get_tokenizer (103-120)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (346-523)
modelopt/torch/quantization/config.py (1)
  • need_calibration (1164-1190)
modelopt/torch/quantization/utils.py (1)
  • patch_fsdp_mp_dtypes (485-527)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (171-246)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
modelopt/torch/quantization/model_quant.py (1)
  • print_quant_summary (463-470)
🪛 LanguageTool
examples/llm_ptq/README.md

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

Comment on lines +24 to +25
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.quantization.config import need_calibration

🛠️ Refactor suggestion | 🟠 Major

Prefer public export API over private helper

After exposing accelerator in export_hf_checkpoint, switch to the public API.

-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

And below (see next comment) call the public function.

🤖 Prompt for AI Agents
In examples/llm_ptq/multinode_ptq.py around lines 24-25, the code imports the
private helper _export_hf_checkpoint; change this to import the public
export_hf_checkpoint from modelopt.torch.export.unified_export_hf and update any
subsequent calls to use the public name. Ensure the import line uses the public
function and that later code calls export_hf_checkpoint with the same arguments
(including accelerator) instead of the underscored helper.

post_state_dict, hf_quant_config = _export_hf_checkpoint(
model, torch.bfloat16, accelerator=accelerator
)

🛠️ Refactor suggestion | 🟠 Major

Call the public API with accelerator

Use export_hf_checkpoint(..., accelerator=accelerator) once the public function accepts it.

-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
-        model, torch.bfloat16, accelerator=accelerator
-    )
+    post_state_dict, hf_quant_config = export_hf_checkpoint(
+        model, torch.bfloat16, export_dir=export_dir, save_modelopt_state=False, accelerator=accelerator
+    )
+    # Note: if you let export_hf_checkpoint write files directly, you can simplify the save steps below.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/llm_ptq/multinode_ptq.py around lines 233 to 235, the code calls the
internal helper _export_hf_checkpoint(..., accelerator=accelerator); replace
this with the public API export_hf_checkpoint(..., accelerator=accelerator) once
that function accepts the accelerator parameter, ensure the public function is
imported at the top of the file (remove or stop using the underscored helper),
pass the same arguments including accelerator=accelerator, and update any
references or tests that expect the internal function if necessary.

Comment on lines +30 to 31
from torch.distributed.fsdp import FSDPModule


⚠️ Potential issue | 🟠 Major

Make FSDP import robust across torch versions

Guard the import and avoid hard failures on environments without composable FSDP. Also gate isinstance checks.

-from torch.distributed.fsdp import FSDPModule
+try:
+    # Prefer composable FSDP module type when available
+    from torch.distributed._composable.fsdp import FSDPModule as _FSDPModule  # PyTorch ≥2.3
+except Exception:  # pragma: no cover
+    _FSDPModule = None

And below, use _FSDPModule guarded checks (see further diff).

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines 462 to +483
for name, sub_module in layer_pool.items():
# Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
if isinstance(sub_module, FSDPModule):
# Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
# We need to reshard the previous FSDPModule to prevent potential OOM.
# This hack reduces the number of unshard reshard operations, to avoid unnecessary communication.
if fsdp_module_to_reshard is not None:
fsdp_module_to_reshard.reshard()

fsdp_module_to_reshard = sub_module


⚠️ Potential issue | 🟡 Minor

Guard isinstance check to avoid NameError when FSDP is unavailable

Use the guarded _FSDPModule from the import refactor.

-        if isinstance(sub_module, FSDPModule):
+        if _FSDPModule is not None and isinstance(sub_module, _FSDPModule):
             # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
             # We need to reshard the previous FSDPModule to prevent potential OOM.
             # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication.
             if fsdp_module_to_reshard is not None:
                 fsdp_module_to_reshard.reshard()
 
             fsdp_module_to_reshard = sub_module
🤖 Prompt for AI Agents
In modelopt/torch/export/unified_export_hf.py around lines 473 to 483, the
isinstance check uses FSDPModule which can raise NameError when FSDP isn't
available; replace that check to use the guarded import _FSDPModule (the
refactored alias) so the code safely skips FSDP-specific logic when absent, and
ensure the module imports the _FSDPModule symbol instead of FSDPModule or falls
back to None before the loop.

@sugunav14 sugunav14 force-pushed the svelury/multinode-ptq branch from d51b652 to 3b80298 on October 22, 2025 04:04
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (8)
tests/gpu/torch/export/test_export.py (1)

378-390: Remove debug print statement.

The debug print on line 383 was flagged in a previous review as unnecessary.

Apply this diff:

     for name, module in model.named_modules():
         if isinstance(module, TensorQuantizer) and module.is_enabled:
             scale = get_scaling_factor(module)
-            print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
             assert torch.allclose(
                 scale,
                 torch.tensor((expected_amax[0] / maxbound), dtype=scale.dtype),
                 rtol=1e-3,
                 atol=1e-3,
             )
             expected_amax.pop(0)
modelopt/torch/quantization/utils.py (1)

564-582: Document thread-safety concerns and the reason for this global patch.

The no_requires_grad context manager globally patches torch.nn.Parameter.__new__, which is invasive and not thread-safe. This should be clearly documented in the docstring to warn users.

Based on past review feedback, enhance the docstring to warn about the global nature and thread-safety:

 @contextmanager
 def no_requires_grad():
-    """Context manager to temporarily set requires_grad to False.
+    """Context manager to globally patch Parameter creation to set requires_grad=False.
 
-    This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates
-    a new parameter with default requires_grad and then update the requires_grad attribute as needed. This
-    triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True
-    for integer tensors.
+    **WARNING**: This patches `torch.nn.Parameter.__new__` globally and is NOT thread-safe.
+    Do not use in multi-threaded contexts or when other code might be creating Parameters concurrently.
+    
+    This workaround allows calling init_sharded_parameter() on compressed integer weights.
+    FSDP2 creates Parameters with requires_grad=True by default, which errors for integer dtypes.
+    The patch ignores the requires_grad argument and forces it to False. The original __new__ is
+    restored in the finally block.
     """
examples/llm_ptq/example_utils.py (1)

42-93: Critical: Fix implicit None return and prevent shared config mutation.

The function has two critical issues flagged in past reviews:

  1. Implicit None return: When auto_quantize=True, the function never explicitly returns (there is no code path after line 52's if block), yielding None to callers, who then invoke need_calibration(None) and crash.

  2. Shared config mutation: Line 57 uses direct assignment instead of deep copy for non-AWQ formats. Subsequent modifications (lines 77-91) then mutate the shared quant_cfg_choices dictionary, polluting state across multiple calls.

Apply this fix:

 def build_quant_cfg(
     qformat,
     kv_cache_qformat,
     awq_block_size,
     auto_quantize,
     model_type,
     quant_cfg_choices,
     kv_quant_cfg_choices,
 ):
     quant_cfg = {}
     if not auto_quantize:
         assert qformat in quant_cfg_choices, (
             f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
         )
 
-        quant_cfg = quant_cfg_choices[qformat]
+        # Always work on a private copy to avoid mutating shared config
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
 
         if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
             weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
             if isinstance(weight_quantizer, list):
                 weight_quantizer = weight_quantizer[0]
             # If awq_block_size argument is provided, update weight_quantizer
             if awq_block_size:
                 weight_quantizer["block_sizes"][-1] = awq_block_size
 
             # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
             if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
                 quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
 
         enable_quant_kv_cache = kv_cache_qformat != "none"
         print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
 
         # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
         if enable_quant_kv_cache:
             quant_cfg = apply_kv_cache_quant(
                 quant_cfg,
                 getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
             )
 
         # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
         if model_type == "gemma" and "int8_sq" in qformat:
             quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
         if model_type == "phi4mm":
             # Only quantize the language model
             quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
             quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
-
-    return quant_cfg
+        
+        return quant_cfg
+    else:
+        # Auto-quantize path: return empty config or apply model-specific settings
+        # The actual format selection happens in auto_quantize()
+        if model_type == "phi4mm":
+            quant_cfg["quant_cfg"] = {
+                "default": {"enable": False},
+                "*speech*": {"enable": False},
+                "*audio*": {"enable": False},
+                "*image*": {"enable": False},
+                "*vision*": {"enable": False},
+            }
+        return quant_cfg
tests/gpu/torch/export/test_fsdp2_export.py (1)

217-233: Use a fixture to skip when <2 GPUs (already noted earlier)

Adopt the repo’s multi‑GPU skip fixture to avoid spurious CI failures.

Would you like me to patch this to use a need_2_gpus fixture and parametrize accordingly?

examples/llm_ptq/multinode_ptq.py (2)

24-25: Avoid private export helper; prefer public API once accelerator is supported

Use export_hf_checkpoint(...) and pass accelerator when the public function accepts it; simplifies save flow.

-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
@@
-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
-        model, torch.bfloat16, accelerator=accelerator
-    )
+    post_state_dict, hf_quant_config = export_hf_checkpoint(
+        model, torch.bfloat16, export_dir=export_dir, save_modelopt_state=False, accelerator=accelerator
+    )
+    # If export_hf_checkpoint writes files directly, you can drop the manual save steps below.

Also applies to: 233-235


60-76: Verify README/options consistency for qformat/KV‑cache flags

Ensure README uses the same flag names and values (e.g., --kv_cache_qformat vs any aliases like --kv_cache_quant; presence of int8_sq if documented).

I can update either the README or CLI to add aliases/choices once you confirm intended names.

modelopt/torch/export/unified_export_hf.py (2)

509-514: Don’t silently fallback when exporting FSDP models without Accelerator

Falling back to model.state_dict() produces incomplete state for sharded models. Error out if FSDP is detected but no accelerator is provided.

Apply:

-    if accelerator is not None:
-        # Gather state_dict from all ranks
-        quantized_state_dict = accelerator.get_state_dict(model)
-    else:
-        quantized_state_dict = model.state_dict()
+    if accelerator is not None:
+        quantized_state_dict = accelerator.get_state_dict(model)
+    else:
+        if fsdp_module_to_reshard is not None:
+            raise ImportError(
+                "accelerate must be installed and an Accelerator instance provided "
+                "to export from an FSDP2-wrapped model."
+            )
+        quantized_state_dict = model.state_dict()

473-483: Reshard the last FSDP module after the loop

The final FSDP module remains unsharded and holds extra memory. Add a trailing reshard.

     for name, sub_module in layer_pool.items():
         ...
         if _FSDPModule is not None and isinstance(sub_module, _FSDPModule):
             ...
             fsdp_module_to_reshard = sub_module

-    if accelerator is not None:
+    # Reshard the last FSDP module, if any
+    if fsdp_module_to_reshard is not None:
+        fsdp_module_to_reshard.reshard()
+
+    if accelerator is not None:
         ...

Also applies to: 509-513

🧹 Nitpick comments (7)
modelopt/torch/quantization/utils.py (1)

484-527: Enhance documentation for the FSDP2 dtype patch.

While the docstring mentions this is copied from the latest torch FSDP repository, it would benefit from additional context explaining:

  1. Why this patch is necessary (relaxing uniform dtype requirement during quantization)
  2. The scope/lifetime of the patch (applied during context, restored after)
  3. Any potential side effects or limitations

Based on learnings from past reviews.

Consider enhancing the docstring:

 @contextmanager
 def patch_fsdp_mp_dtypes():
     """Patch FSDP2 to handle mixed dtypes properly during quantization.
 
     This patch is used to relax the requirement of uniform original parameter dtype in FSDP2 and is
     copied from the latest torch FSDP repository `torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py <https://github.com/pytorch/pytorch/blob/c40048472cc4e28f44e8e835cae319add231bf5/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py#L227>`_.
+    
+    This is necessary because during quantization workflows with FSDP2, parameters may temporarily
+    have non-uniform dtypes (e.g., mixing float and quantized integer tensors). The original FSDP2
+    implementation enforces uniform dtypes, which would cause assertion errors. This context manager
+    temporarily replaces the _init_mp_dtypes method to allow non-uniform dtypes, then restores the
+    original implementation when exiting the context.
     """
examples/llm_ptq/README.md (1)

238-269: Comprehensive multi-node FSDP2 documentation with minor grammar fix.

The documentation clearly explains the FSDP2 multi-node setup, command usage, and deployment options. One minor grammar improvement on line 244.

Apply this diff to fix the grammar:

-For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements.
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.
tests/_test_utils/torch_export/export_utils.py (1)

57-83: Document the TODO or file an issue for FSDP2 layernorm bias issue.

The TODO comment on line 68 mentions that FSDP2 modifies the bias of layernorm for AWQ. This should either be:

  1. Documented with more context (is this expected behavior or a bug?)
  2. Tracked in a separate issue

Do you want me to help create an issue to track this FSDP2/AWQ layernorm bias behavior, or should the TODO be expanded with more context about whether this is expected?

modelopt/torch/quantization/qtensor/base_qtensor.py (2)

22-22: Make FSDP import/version handling robust

Composable FSDP lives under torch.distributed._composable.fsdp in recent PyTorch. Guard imports and isinstance checks to avoid hard failures on older/newer torch.

Apply:

-from torch.distributed.fsdp import FSDPModule, fully_shard
+try:
+    # Preferred: composable FSDP (PyTorch ≥2.3)
+    from torch.distributed._composable.fsdp import FSDPModule as _FSDPModule, fully_shard
+except Exception:  # pragma: no cover
+    _FSDPModule = None

And below:

-            if isinstance(m, FSDPModule):
+            if _FSDPModule is not None and isinstance(m, _FSDPModule):
                 _compress_fsdp_module(m)

Also applies to: 257-259


235-236: Local import OK; add brief comment for rationale

Lazy-importing utils avoids import cycles; add a one‑liner comment for future readers.

-        from modelopt.torch.quantization.utils import enable_fake_quant, fsdp2_aware_weight_update
+        # Lazy import to avoid circulars and keep FSDP2 helpers isolated
+        from modelopt.torch.quantization.utils import enable_fake_quant, fsdp2_aware_weight_update
modelopt/torch/export/unified_export_hf.py (1)

30-31: Guard FSDP import across torch versions

Avoid hard dependency on composable FSDP symbol location.

-from torch.distributed.fsdp import FSDPModule
+try:
+    from torch.distributed._composable.fsdp import FSDPModule as _FSDPModule
+except Exception:  # pragma: no cover
+    _FSDPModule = None

And where used:

-        if isinstance(sub_module, FSDPModule):
+        if _FSDPModule is not None and isinstance(sub_module, _FSDPModule):
tests/gpu/torch/export/test_fsdp2_export.py (1)

163-168: Remove duplicate local import

Redundant import copy inside function; already imported at the top.

-    import copy
-
     from torch.distributed._composable.fsdp import fully_shard
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a4c7a5a and 3b80298.

📒 Files selected for processing (13)
  • CHANGELOG.rst (1 hunks)
  • examples/llm_ptq/README.md (1 hunks)
  • examples/llm_ptq/example_utils.py (2 hunks)
  • examples/llm_ptq/fsdp2.yaml (1 hunks)
  • examples/llm_ptq/hf_ptq.py (2 hunks)
  • examples/llm_ptq/multinode_ptq.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (10 hunks)
  • modelopt/torch/quantization/qtensor/base_qtensor.py (3 hunks)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/_test_utils/torch_export/export_utils.py (2 hunks)
  • tests/gpu/torch/export/test_export.py (3 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
🧰 Additional context used
🧬 Code graph analysis (9)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
  • SmallLinearModelwithCustomWeight (36-54)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
  • maxbound (193-199)
modelopt/torch/export/quant_utils.py (1)
  • get_scaling_factor (212-229)
examples/llm_ptq/hf_ptq.py (1)
examples/llm_ptq/example_utils.py (1)
  • build_quant_cfg (42-93)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
modelopt/torch/quantization/utils.py (2)
  • enable_fake_quant (586-599)
  • fsdp2_aware_weight_update (603-699)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
  • QFSDPParam (139-159)
  • QTensorWrapper (87-136)
tests/_test_utils/torch_export/export_utils.py (2)
modelopt/torch/export/model_config.py (1)
  • bias (153-163)
modelopt/torch/quantization/backends/nvfp4_gemm.py (1)
  • forward (135-150)
examples/llm_ptq/example_utils.py (3)
modelopt/torch/export/model_config.py (1)
  • awq_block_size (289-294)
examples/llm_ptq/hf_ptq.py (1)
  • auto_quantize (94-152)
modelopt/torch/quantization/model_quant.py (1)
  • auto_quantize (234-450)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-699)
  • quantizer_attr_names (231-242)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_megatron.py (1)
  • state_dict (465-469)
examples/llm_ptq/multinode_ptq.py (8)
examples/llm_ptq/example_utils.py (2)
  • build_quant_cfg (42-93)
  • get_tokenizer (103-120)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (346-523)
modelopt/torch/quantization/config.py (1)
  • need_calibration (1164-1190)
modelopt/torch/quantization/utils.py (1)
  • patch_fsdp_mp_dtypes (485-527)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (171-246)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
modelopt/torch/quantization/model_quant.py (1)
  • print_quant_summary (463-470)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (205-343)
  • requantize_resmooth_fused_llm_layers (88-202)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-699)
  • patch_fsdp_mp_dtypes (485-527)
🪛 LanguageTool
examples/llm_ptq/README.md

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (20)
examples/llm_ptq/fsdp2.yaml (1)

1-30: LGTM!

The FSDP2 configuration is well-structured and clearly documented. The comment on line 2 appropriately addresses the file's usage, and the settings align with the multi-node PTQ workflow described in the PR objectives.

modelopt/torch/quantization/utils.py (6)

18-35: LGTM!

The addition of TYPE_CHECKING, Generator, and __future__ annotations improves type hinting and code organization.


366-369: LGTM!

The _get_module_name helper function provides a clean way to locate a module's qualified name within the root model.


530-545: LGTM!

The get_prefixed_param_names function correctly derives full parameter names by checking parameter identity against the parent model.


548-561: LGTM!

The create_fsdp_param_mapping function provides a clean mapping from module names to FSDPParam objects.


585-599: LGTM!

The enable_fake_quant context manager correctly toggles the _fake_quant attribute on weight quantizers.


602-699: LGTM with excellent documentation!

The fsdp2_aware_weight_update context manager is well-documented with clear parameter descriptions and behavior explanation. The implementation correctly handles unsharding, FSDPParam updates, and optional resharding.

CHANGELOG.rst (1)

15-15: LGTM!

The changelog entry accurately documents the new FSDP2 multi-node PTQ support and provides a helpful link to the README for details.

examples/llm_ptq/hf_ptq.py (2)

27-27: LGTM!

The import of build_quant_cfg is appropriate for the centralized config building approach.


451-459: Excellent refactoring to centralize config building.

Replacing the inline quant_cfg construction with the build_quant_cfg function call improves maintainability and reduces code duplication across PTQ scripts.

tests/gpu/torch/export/test_export.py (2)

19-19: LGTM!

The import update reflects the model refactoring in export_utils.py.


309-309: LGTM!

The model instantiation correctly uses the renamed SmallLinearModelwithCustomWeight class.

tests/_test_utils/torch_export/export_utils.py (2)

21-30: LGTM!

Adding the bias parameter to ToyModel provides useful testing flexibility.


36-54: LGTM!

The renaming to SmallLinearModelwithCustomWeight makes the purpose of this test model clearer.

examples/llm_ptq/example_utils.py (2)

16-16: LGTM!

The copy import is necessary for deep-copying configuration dictionaries.


36-36: LGTM!

The import of modelopt.torch.quantization as mtq is necessary for accessing quantization configurations.

modelopt/torch/export/unified_export_hf.py (1)

118-201: Wrap fusions in minimal scopes; tiny doc tweak

The fsdp2_aware_weight_update usage here looks correct. Add a short comment explaining why reshard=False is used (to reduce communication).

If desired, I can add a brief code comment summarizing the unshard-once/reshard-once rationale across this block.

tests/gpu/torch/export/test_fsdp2_export.py (1)

195-215: Good parity checks between sharded and non‑sharded paths

State/buffer comparisons are robust (bf16). Nice coverage.

examples/llm_ptq/multinode_ptq.py (2)

182-215: Calibration loop uses outer FSDP model intentionally

Using the wrapped model in the closure is correct for DTensor paths.


279-287: Top‑level Accelerator import is acceptable for this example

Given FSDP2 is required here, the hard dependency is fine.

Comment on lines 250 to 253
for _, submodule in fsdp_module.named_modules():
with fsdp2_aware_weight_update(fsdp_module, submodule):
_compress_and_update_module_weight(submodule)


⚠️ Potential issue | 🔴 Critical

Don’t wrap every submodule; filter and reshard once

Calling fsdp2_aware_weight_update on all named_modules() will assert for non‑mapped modules and causes repeated un/reshard. Filter to eligible modules, batch them in one context with reshard=False, then reshard once.

Apply:

-        for _, submodule in fsdp_module.named_modules():
-            with fsdp2_aware_weight_update(fsdp_module, submodule):
-                _compress_and_update_module_weight(submodule)
+        # Compress only eligible submodules and reshard once at the end
+        eligible: list[torch.nn.Module] = []
+        for _, submodule in fsdp_module.named_modules():
+            if (
+                hasattr(submodule, "weight")
+                and submodule.weight is not None
+                and not getattr(submodule.weight, "is_meta", False)
+                and hasattr(submodule, "weight_quantizer")
+                and getattr(submodule.weight_quantizer, "is_enabled", False)
+                and not getattr(submodule.weight_quantizer, "_fake_quant", False)
+                and submodule.weight.element_size() > 1
+            ):
+                eligible.append(submodule)
+
+        if not eligible:
+            fsdp_module.reshard()
+            return
+
+        with fsdp2_aware_weight_update(fsdp_module, eligible, reshard=False):
+            for submodule in eligible:
+                _compress_and_update_module_weight(submodule)
+
+        fsdp_module.reshard()

Committable suggestion skipped: line range outside the PR's diff.

Signed-off-by: Suguna Velury <[email protected]>
@sugunav14 sugunav14 force-pushed the svelury/multinode-ptq branch from 54fd2b5 to b82f95b on October 22, 2025 09:32
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/unified_export_hf.py (1)

473-508: Reshard the final FSDP module to avoid memory leak.

The loop reshards fsdp_module_to_reshard only when a new FSDPModule is encountered (lines 479-480), leaving the final module unresharded after the loop completes. This keeps the full unsharded weights in memory.

Apply this diff after line 507:

                     with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                         _export_quantized_weight(sub_module, dtype, weight_name)
 
+    # Reshard the last FSDP module if any
+    if fsdp_module_to_reshard is not None:
+        fsdp_module_to_reshard.reshard()
+
     if accelerator is not None:
♻️ Duplicate comments (3)
modelopt/torch/export/unified_export_hf.py (1)

509-513: Guard against FSDP export without accelerator.

When FSDP modules are detected (fsdp_module_to_reshard is not None) but accelerator is None, the code falls back to model.state_dict() which cannot properly gather sharded state across ranks, producing incomplete results.

Apply this diff:

     if accelerator is not None:
         # Gather state_dict from all ranks
         quantized_state_dict = accelerator.get_state_dict(model)
     else:
+        if fsdp_module_to_reshard is not None:
+            raise RuntimeError(
+                "FSDP2-wrapped model detected but no accelerator provided. "
+                "Pass accelerator=<Accelerator instance> to export_hf_checkpoint for FSDP2 export."
+            )
         quantized_state_dict = model.state_dict()
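For completeness, the happy-path call site would look roughly like the sketch below. This assumes export_hf_checkpoint accepts accelerator and export_dir keywords (the accelerator keyword is implied by the suggested error message); verify against the actual signature before relying on it, and treat the model path as a placeholder.

from accelerate import Accelerator
from transformers import AutoModelForCausalLM
from modelopt.torch.export import export_hf_checkpoint  # assumed import path

accelerator = Accelerator()  # launched via `accelerate launch --config_file fsdp2.yaml ...`
model = AutoModelForCausalLM.from_pretrained("<hf_model_name_or_path>", torch_dtype="auto")
model = accelerator.prepare(model)  # FSDP2-wraps the model per the yaml config
# ... calibration and mtq.quantize(...) happen here ...

# Passing the accelerator lets export gather the sharded state dict across ranks;
# without it a sharded model would fall back to an incomplete model.state_dict().
export_hf_checkpoint(model, export_dir="exported_ckpt", accelerator=accelerator)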
tests/gpu/torch/export/test_export.py (1)

383-383: Remove debug print statement.

The debug print on line 383 appears to be leftover from development and should be removed before merging.

Apply this diff:

             scale = get_scaling_factor(module)
-            print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}")
             assert torch.allclose(
tests/gpu/torch/export/test_fsdp2_export.py (1)

37-37: Remove redundant imports.

fully_shard is already imported at line 24, and copy at line 17. Remove the duplicate imports inside these functions.

Apply this diff:

 def _update_weight_test(rank, size):
     """Test fsdp2 weight update context for weight update -> only value changed"""
-    from torch.distributed._composable.fsdp import fully_shard
-
     with patch_fsdp_mp_dtypes():
 def _compress_weight_test(rank, size):
     """Test fsdp2 weight update context for weight compression -> only value,shape and dtype changed"""
-    from torch.distributed._composable.fsdp import fully_shard
-
     with patch_fsdp_mp_dtypes():
 def _export_quantized_weight_test(rank, size, quant_config):
-    import copy
-
-    from torch.distributed._composable.fsdp import fully_shard
-
     with patch_fsdp_mp_dtypes():

Also applies to: 75-75, 163-165

🧹 Nitpick comments (6)
tests/_test_utils/torch_export/export_utils.py (1)

36-36: Fix capitalization in class name.

The class name SmallLinearModelwithCustomWeight has inconsistent capitalization. The "with" should be capitalized to follow PEP 8 naming conventions for classes.

Apply this diff:

-class SmallLinearModelwithCustomWeight(torch.nn.Module):
+class SmallLinearModelWithCustomWeight(torch.nn.Module):

Also update the import in tests/gpu/torch/export/test_export.py line 19 to match.

modelopt/torch/quantization/qtensor/base_qtensor.py (1)

221-229: Consider filtering submodules before wrapping to reduce overhead.

The current implementation calls fsdp2_aware_weight_update for every non-root submodule. While the context manager handles non-FSDP cases, filtering to only eligible modules (those with weights and enabled quantizers) before entering the context would reduce overhead and avoid unnecessary context manager invocations.

Consider filtering like:

with (
    SequentialQuantizer.convert_to_single_quantizer(module),
    torch.no_grad(),
    patch_fsdp_mp_dtypes(),
):
    eligible_modules = [
        m for name, m in module.named_modules()
        if name != "" 
        and hasattr(m, "weight") 
        and m.weight is not None 
        and not getattr(m.weight, "is_meta", False)
    ]
    
    for m in eligible_modules:
        with fsdp2_aware_weight_update(module, m):
            _compress_and_update_module_weight(m)
tests/gpu/torch/export/test_fsdp2_export.py (1)

205-208: Remove unnecessary FSDP2 context manager for non-FSDP model.

non_fsdp_model is not FSDP-wrapped, so wrapping the export call in fsdp2_aware_weight_update is unnecessary and misleading. The context manager will just yield without performing any FSDP-specific operations.

Apply this diff:

     for name, sub_module in non_fsdp_model.named_modules():
         if is_quantlinear(sub_module):
-            with fsdp2_aware_weight_update(non_fsdp_model, sub_module):
-                _export_quantized_weight(sub_module, torch.float16)
+            _export_quantized_weight(sub_module, torch.float16)
examples/llm_ptq/multinode_ptq.py (3)

199-210: Clarify parameter naming and comments in calibration loop.

The unwrapped_model parameter is not used, and the comment "We should forward pass using the unwrapped model" contradicts the actual code which uses the outer closure variable model. While this appears intentional for FSDP2 compatibility, the parameter name and comments are confusing.

Consider either:

  • Renaming the parameter to _unwrapped_model to indicate it's intentionally unused
  • Updating the comment to clarify why the outer model is used instead
-    def calibrate(unwrapped_model):
-        """Calibration loop that uses the FSDP-wrapped model."""
+    def calibrate(_unwrapped_model):
+        """Calibration loop that uses the outer FSDP-wrapped model for DTensor compatibility."""
         for batch in tqdm(dataloader, desc="Calibrating"):
             if isinstance(batch, dict):
                 batch = {
                     k: v.to(accelerator.device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()
                 }
-            # Use outer model (FSDP-wrapped), not the parameter
-            # Important: We should forward pass using the unwrapped model
-            # mtq.quantize will unwrap the model & pass to the forward_loop
+            # Use outer FSDP-wrapped model (from closure) for proper DTensor handling
+            # mtq.quantize passes unwrapped model to forward_loop, but FSDP2 requires the wrapped version
             model(**batch)

288-297: Simplify and safeguard calib_size adjustment logic.

The one-liner for adjusting calib_size is complex and could fail with an IndexError if args.calib_size is somehow empty. Consider making it more explicit and adding a safeguard.

Apply this diff:

     # Set default dataset if not provided
     if args.dataset is None:
         args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
         warnings.warn(
             "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
         )
-        # Adjust calib_size to match dataset length by extending or truncating as needed
-        args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
-            : len(args.dataset)
-        ]
+        # Extend calib_size to match dataset count by repeating the last value
+        if len(args.calib_size) < len(args.dataset):
+            last_size = args.calib_size[-1] if args.calib_size else 512
+            args.calib_size.extend([last_size] * (len(args.dataset) - len(args.calib_size)))
+        elif len(args.calib_size) > len(args.dataset):
+            args.calib_size = args.calib_size[: len(args.dataset)]
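
For example, with the two default datasets and a single --calib_size value, the adjusted logic behaves like this (values are hypothetical):

# Hypothetical values: one calib_size entry, two default datasets.
datasets = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
calib_size = [512]
if len(calib_size) < len(datasets):
    calib_size.extend([calib_size[-1]] * (len(datasets) - len(calib_size)))
elif len(calib_size) > len(datasets):
    calib_size = calib_size[: len(datasets)]
print(calib_size)  # -> [512, 512]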

360-360: Guard print statement with main process check.

This print will execute on all processes in a multi-node/multi-GPU setup, causing duplicate output. Add an is_main_process guard for cleaner logging.

Apply this diff:

-    print("Unpatching FSDP2 MP dtypes")
+    if accelerator.is_main_process:
+        print("Unpatching FSDP2 MP dtypes")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 54fd2b5 and b82f95b.

📒 Files selected for processing (13)
  • CHANGELOG.rst (1 hunks)
  • examples/llm_ptq/README.md (1 hunks)
  • examples/llm_ptq/example_utils.py (2 hunks)
  • examples/llm_ptq/fsdp2.yaml (1 hunks)
  • examples/llm_ptq/hf_ptq.py (2 hunks)
  • examples/llm_ptq/multinode_ptq.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (10 hunks)
  • modelopt/torch/quantization/qtensor/base_qtensor.py (2 hunks)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py (1 hunks)
  • modelopt/torch/quantization/utils.py (3 hunks)
  • tests/_test_utils/torch_export/export_utils.py (2 hunks)
  • tests/gpu/torch/export/test_export.py (3 hunks)
  • tests/gpu/torch/export/test_fsdp2_export.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • examples/llm_ptq/hf_ptq.py
  • CHANGELOG.rst
🧰 Additional context used
🧬 Code graph analysis (8)
examples/llm_ptq/multinode_ptq.py (7)
examples/llm_ptq/example_utils.py (2)
  • build_quant_cfg (42-93)
  • get_tokenizer (103-120)
modelopt/torch/export/convert_hf_config.py (1)
  • convert_hf_quant_config_format (21-117)
modelopt/torch/export/unified_export_hf.py (1)
  • _export_hf_checkpoint (346-523)
modelopt/torch/quantization/config.py (1)
  • need_calibration (1164-1190)
modelopt/torch/quantization/utils.py (1)
  • patch_fsdp_mp_dtypes (485-527)
modelopt/torch/utils/dataset_utils.py (1)
  • get_dataset_dataloader (171-246)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-705)
  • patch_fsdp_mp_dtypes (485-527)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • SequentialQuantizer (1114-1222)
  • convert_to_single_quantizer (1197-1222)
tests/gpu/torch/export/test_export.py (3)
tests/_test_utils/torch_export/export_utils.py (1)
  • SmallLinearModelwithCustomWeight (36-54)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
  • maxbound (193-199)
modelopt/torch/export/quant_utils.py (1)
  • get_scaling_factor (212-229)
modelopt/torch/export/unified_export_hf.py (2)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-705)
  • quantizer_attr_names (231-242)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/qtensor/base_qtensor.py (2)
  • QFSDPParam (139-159)
  • QTensorWrapper (87-136)
examples/llm_ptq/example_utils.py (2)
examples/llm_ptq/hf_ptq.py (1)
  • auto_quantize (94-152)
modelopt/torch/quantization/model_quant.py (1)
  • auto_quantize (234-450)
tests/_test_utils/torch_export/export_utils.py (3)
modelopt/torch/export/model_config.py (1)
  • bias (153-163)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
  • forward (141-163)
modelopt/torch/quantization/backends/nvfp4_gemm.py (1)
  • forward (135-150)
tests/gpu/torch/export/test_fsdp2_export.py (4)
tests/_test_utils/torch_export/export_utils.py (2)
  • SmallQKVModel (57-83)
  • ToyModel (20-33)
modelopt/torch/export/layer_utils.py (1)
  • is_quantlinear (346-348)
modelopt/torch/export/unified_export_hf.py (2)
  • _export_quantized_weight (205-343)
  • requantize_resmooth_fused_llm_layers (88-202)
modelopt/torch/quantization/utils.py (2)
  • fsdp2_aware_weight_update (603-705)
  • patch_fsdp_mp_dtypes (485-527)
🪛 LanguageTool
examples/llm_ptq/README.md

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
examples/llm_ptq/fsdp2.yaml (1)

1-30: LGTM! Configuration is well-documented.

The FSDP2 configuration is properly documented (line 2) and includes all necessary settings for multinode PTQ workflows.

modelopt/torch/quantization/utils.py (2)

484-527: LGTM! FSDP2 dtype patching is well-implemented.

The patch_fsdp_mp_dtypes context manager correctly patches and restores FSDP's _init_mp_dtypes to handle mixed-precision dtypes during quantization. The implementation follows the upstream PyTorch pattern and properly cleans up in the finally block.
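The consumer-side pattern is simply a context-managed scope, as the GPU tests in this PR use; a minimal sketch, assuming torch.distributed is already initialized on each rank and that the import path matches where the helper is defined:

import torch
from torch.distributed._composable.fsdp import fully_shard
from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes

def shard_and_update(model: torch.nn.Module) -> None:
    # While the patch is active, FSDP's mixed-precision dtype initialization
    # tolerates the non-uniform dtypes introduced by quantized/compressed weights;
    # the original _init_mp_dtypes is restored on exit, even if an exception is raised.
    with patch_fsdp_mp_dtypes():
        fully_shard(model)  # FSDP2 sharding under the patched dtype logic
        ...                 # weight updates / compression happen inside this scope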


602-705: LGTM! FSDP2 weight update coordination is correctly implemented.

The fsdp2_aware_weight_update context manager properly coordinates weight updates across FSDP2 shards by:

  1. Unsharding when needed
  2. Tracking FSDPParam mappings
  3. Creating new FSDPParam/QFSDPParam instances after updates
  4. Optionally resharding to release memory

The implementation correctly handles edge cases and provides the necessary synchronization for distributed PTQ workflows.
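On the caller side this reduces to a small pattern; a minimal sketch, assuming an already fully_shard-ed model on an initialized process group and that every Linear here is FSDP-managed (the in-place scaling itself is illustrative, not part of the library):

import torch
from modelopt.torch.quantization.utils import (
    fsdp2_aware_weight_update,
    patch_fsdp_mp_dtypes,
)

def scale_linear_weights(fsdp_model: torch.nn.Module) -> None:
    """Mutate Linear weights in place while keeping FSDP2 shards consistent."""
    with patch_fsdp_mp_dtypes():
        for _, submodule in fsdp_model.named_modules():
            if not isinstance(submodule, torch.nn.Linear):
                continue
            # The context unshards as needed, tracks the FSDPParam mapping, and
            # re-registers and reshards the parameter once the update is done.
            with fsdp2_aware_weight_update(fsdp_model, submodule):
                submodule.weight.data.mul_(0.5)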

Comment on lines +57 to +80
        quant_cfg = quant_cfg_choices[qformat]

        if "awq" in qformat:
            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            # If awq_block_size argument is provided, update weight_quantizer
            if awq_block_size:
                weight_quantizer["block_sizes"][-1] = awq_block_size

            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
            if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

        enable_quant_kv_cache = kv_cache_qformat != "none"
        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
        if enable_quant_kv_cache:
            quant_cfg = apply_kv_cache_quant(
                quant_cfg,
                getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
            )
⚠️ Potential issue | 🟠 Major

Always deep-copy quant_cfg to prevent shared state mutation.

Non-AWQ formats use a shallow reference (line 57), but apply_kv_cache_quant (lines 77-80) mutates quant_cfg["quant_cfg"]. This pollutes the shared quant_cfg_choices module object across calls, similar to the issue flagged in past reviews.

Apply this diff:

-        quant_cfg = quant_cfg_choices[qformat]
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
 
         if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
             weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
🤖 Prompt for AI Agents
In examples/llm_ptq/example_utils.py around lines 57 to 80, the code assigns
quant_cfg = quant_cfg_choices[qformat] by reference for non-AWQ formats but
later mutate quant_cfg via apply_kv_cache_quant, which leaks into the shared
quant_cfg_choices; change the initial assignment to always make a deep copy
(e.g., quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])) so both AWQ and
non-AWQ branches operate on an isolated copy before any in-place updates
(preserve the existing extra deepcopy in the AWQ branch if present).
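
The underlying failure mode is plain Python aliasing; a standalone illustration with a hypothetical stand-in for quant_cfg_choices (not the real configuration contents):

import copy

# Hypothetical stand-in for the shared, module-level quant_cfg_choices mapping.
QUANT_CFG_CHOICES = {"fp8": {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}}}

def build_cfg_shallow(qformat: str) -> dict:
    cfg = QUANT_CFG_CHOICES[qformat]           # reference, not a copy
    cfg["quant_cfg"]["*k_bmm_quantizer"] = {}  # mutation leaks into the shared dict
    return cfg

def build_cfg_safe(qformat: str) -> dict:
    cfg = copy.deepcopy(QUANT_CFG_CHOICES[qformat])  # isolated copy
    cfg["quant_cfg"]["*k_bmm_quantizer"] = {}        # no leak across calls
    return cfg

build_cfg_shallow("fp8")
print("*k_bmm_quantizer" in QUANT_CFG_CHOICES["fp8"]["quant_cfg"])  # True -> polluted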


### Usage

For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements.
⚠️ Potential issue | 🟡 Minor

Fix hyphenation for compound modifier.

The phrase "user specific requirements" should use a hyphen when the compound modifier precedes the noun.

Apply this diff:

-For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements.
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.
🧰 Tools
🪛 LanguageTool

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

🤖 Prompt for AI Agents
In examples/llm_ptq/README.md around line 244, the phrase "user specific
requirements" should be hyphenated as "user-specific requirements" when used as
a compound modifier before the noun; update the sentence to use the hyphenated
form to fix the compound modifier hyphenation.

@sugunav14 sugunav14 merged commit b8dbfc0 into main Oct 22, 2025
27 checks passed
@sugunav14 sugunav14 deleted the svelury/multinode-ptq branch October 22, 2025 17:59