1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -12,6 +12,7 @@ Model Optimizer Changelog (Linux)
- Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` (gated dataset accessed using ``HF_TOKEN`` environment variable) if no dataset is specified.
- Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
- Add support for MCore MoE PTQ/QAT/QAD.
- Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/llm_ptq#multi-node-post-training-quantization-with-fsdp2>`_ for more details.

**Documentation**

32 changes: 32 additions & 0 deletions examples/llm_ptq/README.md
@@ -235,6 +235,38 @@ with init_quantized_weights(mtq.NVFP4_DEFAULT_CFG):
mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop)
```

## Multi-Node Post-Training Quantization with FSDP2

ModelOpt enables quantization of LLMs across multiple GPU nodes using various quantization formats. It leverages Hugging Face's Accelerate library and FSDP2 for distributed model sharding and calibration.
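
For orientation, the sketch below shows the general pattern the script follows: Accelerate shards the model with FSDP2, and ModelOpt calibrates the sharded model through a forward loop. This is a simplified illustration only; the model path, calibration data, and FP8 config choice are placeholders, and the real `multinode_ptq.py` additionally handles dataset loading, KV-cache quantization, and checkpoint export.

```python
# Simplified sketch of the multi-node calibration pattern (not the full multinode_ptq.py).
# Launch on every node with: accelerate launch --config_file fsdp2.yaml this_sketch.py
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq

accelerator = Accelerator()  # picks up the FSDP2 settings passed via `accelerate launch`
model = AutoModelForCausalLM.from_pretrained("<path_to_model>", torch_dtype=torch.bfloat16)
model = accelerator.prepare(model)  # shards the model across all ranks with FSDP2

tokenizer = AutoTokenizer.from_pretrained("<path_to_model>")
calib_batches = [tokenizer("Sample calibration text.", return_tensors="pt")]  # placeholder data

def calibrate_loop(m):
    # Each rank runs forward passes so ModelOpt can collect activation statistics.
    for batch in calib_batches:
        m(**{k: v.to(accelerator.device) for k, v in batch.items()})

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
```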

### Usage

For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.

⚠️ Potential issue | 🟡 Minor

Fix hyphenation for compound modifier.

The phrase "user specific requirements" should use a hyphen when the compound modifier precedes the noun.

Apply this diff:

-For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user specific requirements.
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.
🧰 Tools
🪛 LanguageTool

[grammar] ~244-~244: Use a hyphen to join words.
Context: ... provided and can be customized for user specific requirements. On each node run...

(QB_NEW_EN_HYPHEN)

🤖 Prompt for AI Agents
In examples/llm_ptq/README.md around line 244, the phrase "user specific
requirements" should be hyphenated as "user-specific requirements" when used as
a compound modifier before the noun; update the sentence to use the hyphenated
form to fix the compound modifier hyphenation.


On each node, run the following command:

```bash
accelerate launch --config_file fsdp2.yaml \
    --num_machines=<num_nodes> \
    --machine_rank=<current_node_rank> \
    --main_process_ip=<node0_ip_addr> \
    --main_process_port=<port> \
    --fsdp_transformer_layer_cls_to_wrap=<decoder_layer_name> \
    multinode_ptq.py \
    --pyt_ckpt_path <path_to_model> \
    --qformat <fp8/nvfp4/nvfp4_awq/int8> \
    --kv_cache_qformat <fp8/nvfp4/nvfp4_affine/none> \
    --batch_size <calib_batch_size> \
    --calib_size <num_calib_samples> \
    --dataset <dataset> \
    --export_path <export_path> \
    --trust_remote_code
```

The exported checkpoint can be deployed using TensorRT-LLM, vLLM, or SGLang. For more details, refer to the [deployment section](#deployment) of this document.
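
As a quick smoke test, the exported directory can typically be loaded with vLLM's offline API, shown in the sketch below (a minimal example; it assumes the chosen quantization format is one vLLM supports, such as FP8 or NVFP4; see the deployment section for the supported options per framework).

```python
# Minimal sketch: loading the exported checkpoint with vLLM's offline API.
# "<export_path>" is the directory passed to --export_path above; the quantization
# settings are read from the config files written during export.
from vllm import LLM, SamplingParams

llm = LLM(model="<export_path>")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```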

> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size within available GPU memory and use only as many GPUs as needed to avoid unnecessary communication overhead.*
>
## Framework Scripts

### Hugging Face Example [Script](./scripts/huggingface_example.sh)
56 changes: 56 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import glob
import os
import shutil
@@ -32,11 +33,66 @@
except ImportError:
    snapshot_download = None

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.image_processor import MllamaImageProcessor

SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]


def build_quant_cfg(
    qformat,
    kv_cache_qformat,
    awq_block_size,
    auto_quantize,
    model_type,
    quant_cfg_choices,
    kv_quant_cfg_choices,
):
    quant_cfg = {}
    if not auto_quantize:
        assert qformat in quant_cfg_choices, (
            f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
        )

        quant_cfg = quant_cfg_choices[qformat]

        if "awq" in qformat:
            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            # If awq_block_size argument is provided, update weight_quantizer
            if awq_block_size:
                weight_quantizer["block_sizes"][-1] = awq_block_size

            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
            if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

        enable_quant_kv_cache = kv_cache_qformat != "none"
        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
        if enable_quant_kv_cache:
            quant_cfg = apply_kv_cache_quant(
                quant_cfg,
                getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
            )
Comment on lines +57 to +80

⚠️ Potential issue | 🟠 Major

Always deep-copy quant_cfg to prevent shared state mutation.

Non-AWQ formats use a shallow reference (line 57), but apply_kv_cache_quant (lines 77-80) mutates quant_cfg["quant_cfg"]. This pollutes the shared quant_cfg_choices module object across calls, similar to the issue flagged in past reviews.

Apply this diff:

-        quant_cfg = quant_cfg_choices[qformat]
+        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
 
         if "awq" in qformat:
-            quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
             weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
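
For illustration, the aliasing described above boils down to the following toy sketch (the dict contents are stand-ins, not the real quantization configs):

```python
import copy

# Stand-in for the module-level quant_cfg_choices mapping.
quant_cfg_choices = {"fp8": {"quant_cfg": {"*weight_quantizer": {"num_bits": (4, 3)}}}}

shared = quant_cfg_choices["fp8"]              # reference, not a copy
shared["quant_cfg"]["*output_quantizer"] = {}  # simulates apply_kv_cache_quant mutating it
print("*output_quantizer" in quant_cfg_choices["fp8"]["quant_cfg"])  # True: module state polluted

isolated = copy.deepcopy(quant_cfg_choices["fp8"])  # a deep copy keeps later mutations local
isolated["quant_cfg"]["*kv_bmm_quantizer"] = {}
assert "*kv_bmm_quantizer" not in quant_cfg_choices["fp8"]["quant_cfg"]
```
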
🤖 Prompt for AI Agents
In examples/llm_ptq/example_utils.py around lines 57 to 80, the code assigns
quant_cfg = quant_cfg_choices[qformat] by reference for non-AWQ formats but
later mutate quant_cfg via apply_kv_cache_quant, which leaks into the shared
quant_cfg_choices; change the initial assignment to always make a deep copy
(e.g., quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])) so both AWQ and
non-AWQ branches operate on an isolated copy before any in-place updates
(preserve the existing extra deepcopy in the AWQ branch if present).


        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
        if model_type == "gemma" and "int8_sq" in qformat:
            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

        if model_type == "phi4mm":
            # Only quantize the language model
            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}

    return quant_cfg


def is_speculative(hf_config):
    """Check if the model architecture is a speculative model."""
    return hf_config.architectures and any(
30 changes: 30 additions & 0 deletions examples/llm_ptq/fsdp2.yaml
@@ -0,0 +1,30 @@
# =============================================================================
# FSDP Configuration for running LLM PTQ on a multi-node setup. This file is consumed by examples/llm_ptq/multinode_ptq.py
# =============================================================================

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 2
num_processes: 16
rdzv_backend: c10d
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
52 changes: 10 additions & 42 deletions examples/llm_ptq/hf_ptq.py
@@ -14,7 +14,6 @@
# limitations under the License.

import argparse
import copy
import random
import time
import warnings
@@ -25,6 +24,7 @@
from accelerate.hooks import remove_hook_from_module
from example_utils import (
    apply_kv_cache_quant,
    build_quant_cfg,
    copy_custom_model_files,
    get_model,
    get_processor,
@@ -448,47 +448,15 @@ def main(args):
        include_labels=args.auto_quantize_bits is not None,
    )

    quant_cfg = {}
    if not args.auto_quantize_bits:
        assert args.qformat in QUANT_CFG_CHOICES, (
            f"Unsupported quantization format: {args.qformat} with {args.kv_cache_qformat} KV cache"
        )

        quant_cfg = QUANT_CFG_CHOICES[args.qformat]

        if "awq" in args.qformat:
            quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            # If awq_block_size argument is provided, update weight_quantizer
            if args.awq_block_size:
                weight_quantizer["block_sizes"][-1] = args.awq_block_size

            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
            if args.qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

        enable_quant_kv_cache = args.kv_cache_qformat != "none"
        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
        if enable_quant_kv_cache:
            quant_cfg = apply_kv_cache_quant(
                quant_cfg,
                getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
            )

        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
        if model_type == "gemma" and "int8_sq" in args.qformat:
            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

        if model_type == "phi4mm":
            # Only quantize the language model
            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
    quant_cfg = build_quant_cfg(
        args.qformat,
        args.kv_cache_qformat,
        args.awq_block_size,
        args.auto_quantize_bits,
        model_type,
        QUANT_CFG_CHOICES,
        KV_QUANT_CFG_CHOICES,
    )

    if not model_is_already_quantized or calibration_only:
        # Only run single sample for preview