diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index cdf0b5aa7..d59ec07cd 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -12,6 +12,7 @@ Model Optimizer Changelog (Linux)
 - Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` (gated dataset accessed using ``HF_TOKEN`` environment variable) if no dataset is specified.
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
 - Add support for MCore MoE PTQ/QAT/QAD.
+- Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md `_ for more details.
 
 **Documentation**
diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 46780b368..1e97f4b44 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -235,6 +235,38 @@ with init_quantized_weights(mtq.NVFP4_DEFAULT_CFG):
     mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop)
 ```
 
+## Multi-Node Post-Training Quantization with FSDP2
+
+ModelOpt enables quantization of LLMs across multiple GPU nodes using various quantization formats. It leverages HuggingFace's Accelerate library and FSDP2 for distributed model sharding and calibration.
+
+### Usage
+
+For distributed execution across multiple nodes, use the `accelerate` library. A template configuration file (`fsdp2.yaml`) is provided and can be customized for user-specific requirements.
+
+On each node, run the following command:
+
+```bash
+accelerate launch --config_file fsdp2.yaml \
+    --num_machines=<num_nodes> \
+    --machine_rank=<node_rank> \
+    --main_process_ip=<main_node_ip> \
+    --main_process_port=<main_node_port> \
+    --fsdp_transformer_layer_cls_to_wrap=<decoder_layer_class> \
+    multinode_ptq.py \
+    --pyt_ckpt_path <model_path> \
+    --qformat <quant_format> \
+    --kv_cache_qformat <kv_cache_quant_format> \
+    --batch_size <batch_size> \
+    --calib_size <calib_size> \
+    --dataset <dataset_name> \
+    --export_path <export_path> \
+    --trust_remote_code
+```
+
+The exported checkpoint can be deployed with TensorRT-LLM, vLLM, or SGLang. For more details, refer to the [deployment section](#deployment) of this document.
+
+> *Performance Note: FSDP2 is designed for training workloads and may result in longer calibration and export times. For faster calibration, maximize the batch size that fits in available GPU memory and avoid using more GPUs than necessary to limit communication overhead.*
+>
 ## Framework Scripts
 
 ### Hugging Face Example [Script](./scripts/huggingface_example.sh)
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 38e11a8e1..d6ae283a1 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy import glob import os import shutil @@ -32,11 +33,66 @@ except ImportError: snapshot_download = None +import modelopt.torch.quantization as mtq from modelopt.torch.utils.image_processor import MllamaImageProcessor SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] +def build_quant_cfg( + qformat, + kv_cache_qformat, + awq_block_size, + auto_quantize, + model_type, + quant_cfg_choices, + kv_quant_cfg_choices, +): + quant_cfg = {} + if not auto_quantize: + assert qformat in quant_cfg_choices, ( + f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" + ) + + quant_cfg = quant_cfg_choices[qformat] + + if "awq" in qformat: + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + # If awq_block_size argument is provided, update weight_quantizer + if awq_block_size: + weight_quantizer["block_sizes"][-1] = awq_block_size + + # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models + if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: + quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} + + enable_quant_kv_cache = kv_cache_qformat != "none" + print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") + + # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. + if enable_quant_kv_cache: + quant_cfg = apply_kv_cache_quant( + quant_cfg, + getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], + ) + + # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. + if model_type == "gemma" and "int8_sq" in qformat: + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} + + if model_type == "phi4mm": + # Only quantize the language model + quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} + quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + + return quant_cfg + + def is_speculative(hf_config): """Check if the model architecture is a speculative model.""" return hf_config.architectures and any( diff --git a/examples/llm_ptq/fsdp2.yaml b/examples/llm_ptq/fsdp2.yaml new file mode 100644 index 000000000..646d63f9e --- /dev/null +++ b/examples/llm_ptq/fsdp2.yaml @@ -0,0 +1,30 @@ +# ============================================================================= +# FSDP Configuration for running LLM PTQ on multinode setup. 
This file is consumed by examples/llm_ptq/multinode_ptq.py +# ============================================================================= + +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_use_orig_params: true + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 2 +num_processes: 16 +rdzv_backend: c10d +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index c55c38abc..dcd3e0f66 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,7 +14,6 @@ # limitations under the License. import argparse -import copy import random import time import warnings @@ -25,6 +24,7 @@ from accelerate.hooks import remove_hook_from_module from example_utils import ( apply_kv_cache_quant, + build_quant_cfg, copy_custom_model_files, get_model, get_processor, @@ -448,47 +448,15 @@ def main(args): include_labels=args.auto_quantize_bits is not None, ) - quant_cfg = {} - if not args.auto_quantize_bits: - assert args.qformat in QUANT_CFG_CHOICES, ( - f"Unsupported quantization format: {args.qformat} with {args.kv_cache_qformat} KV cache" - ) - - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - - if "awq" in args.qformat: - quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - # If awq_block_size argument is provided, update weight_quantizer - if args.awq_block_size: - weight_quantizer["block_sizes"][-1] = args.awq_block_size - - # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models - if args.qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: - quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} - - enable_quant_kv_cache = args.kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - - # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. - if enable_quant_kv_cache: - quant_cfg = apply_kv_cache_quant( - quant_cfg, - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], - ) - - # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. 
- if model_type == "gemma" and "int8_sq" in args.qformat: - quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} - - if model_type == "phi4mm": - # Only quantize the language model - quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} - quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg = build_quant_cfg( + args.qformat, + args.kv_cache_qformat, + args.awq_block_size, + args.auto_quantize_bits, + model_type, + QUANT_CFG_CHOICES, + KV_QUANT_CFG_CHOICES, + ) if not model_is_already_quantized or calibration_only: # Only run single sample for preview diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py new file mode 100644 index 000000000..f3bd4bd59 --- /dev/null +++ b/examples/llm_ptq/multinode_ptq.py @@ -0,0 +1,367 @@ +"""Multi-node PTQ (Post-Training Quantization) with FSDP2 support.""" + +import argparse +import json +import os +import random +import time +import warnings +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from accelerate import Accelerator +from example_utils import build_quant_cfg, get_tokenizer +from tqdm import tqdm +from transformers import AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.torch.export import get_model_type +from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format +from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint +from modelopt.torch.quantization.config import need_calibration +from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes +from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets + +# Constants +RAND_SEED = 1234 + +QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { + "int8": mtq.INT8_DEFAULT_CFG, + "int4_awq": mtq.INT4_AWQ_CFG, + "fp8": mtq.FP8_DEFAULT_CFG, + "nvfp4": mtq.NVFP4_DEFAULT_CFG, + "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, + "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, + "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, +} + +KV_QUANT_CFG_CHOICES = { + "none": "none", + "fp8": "FP8_KV_CFG", + "nvfp4": "NVFP4_KV_CFG", + "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", +} + + +# Enable HuggingFace checkpointing +mto.enable_huggingface_checkpointing() + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Multi-node post-training quantization with FSDP2") + + parser.add_argument( + "--pyt_ckpt_path", + required=True, + help="Path to PyTorch checkpoint", + ) + parser.add_argument( + "--qformat", + default="fp8", + choices=QUANT_CFG_CHOICES.keys(), + help="Quantization format", + ) + parser.add_argument( + "--kv_cache_qformat", + default="fp8", + choices=list(KV_QUANT_CFG_CHOICES.keys()), + help="KV cache quantization format", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for calibration", + ) + parser.add_argument( + "--calib_size", + type=str, + default="512", + help="Comma-separated list of calibration sizes per dataset", + ) + parser.add_argument( + "--dataset", + help=( + f"name of a dataset, or a comma separated list of datasets. 
" + f"dataset choices are {get_supported_datasets()}" + ), + type=str, + default=None, + ) + parser.add_argument( + "--export_path", + default="exported_model", + help="Directory to export the quantized model", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Trust remote code for HuggingFace models", + ) + parser.add_argument("--awq_block_size", default=0, type=int) + + args = parser.parse_args() + + # Parse comma-separated lists + args.dataset = args.dataset.split(",") if args.dataset else None + args.calib_size = [int(x) for x in args.calib_size.split(",")] + + return args + + +def load_and_prepare_model( + model_path: str, + calib_dataloader: torch.utils.data.DataLoader, + accelerator: Accelerator, + trust_remote_code: bool = False, +) -> tuple[nn.Module, str, list[str], torch.utils.data.DataLoader]: + """Load model and prepare it for FSDP2 distributed execution. + + Args: + model_path: Path to the HuggingFace model + calibration_dataloader: Calibration dataloader to be sharded for calibration + accelerator: Accelerate's Accelerator instance + trust_remote_code: Whether to trust remote code + + Returns: + Tuple of (prepared_model, model_type, original_architectures, calibration_dataloader) + """ + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + trust_remote_code=trust_remote_code, + ) + model.eval() + model_type = get_model_type(model) + # Need the original architectures for export + # FSDP prefix is added to the architectures for FSDP2 wrapped models + original_architectures = model.config.architectures + + # FSDP2 requires an optimizer to be prepared together with the model + dummy_optimizer = torch.optim.SGD(model.parameters(), lr=0.0) + model, _, calibration_dataloader = accelerator.prepare(model, dummy_optimizer, calib_dataloader) + + return model, model_type, original_architectures, calibration_dataloader + + +def create_calibration_dataloader( + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + dataset_names: list[str], + calib_sizes: list[int], + batch_size: int, +) -> torch.utils.data.DataLoader: + """Create calibration dataloader from dataset. + + Args: + tokenizer: HuggingFace tokenizer + dataset_names: List of dataset names (defaults to cnn_dailymail) + calib_sizes: Number of samples for each dataset + batch_size: Batch size for calibration + + Returns: + DataLoader for calibration + """ + + return get_dataset_dataloader( + dataset_name=dataset_names, + tokenizer=tokenizer, + batch_size=batch_size, + num_samples=calib_sizes, + device=None, # Keep data on CPU, calibration loop handles device transfer + include_labels=False, + ) + + +def create_fsdp2_calibration_loop( + model: nn.Module, + dataloader: torch.utils.data.DataLoader, + accelerator: Accelerator, +): + """Create calibration loop compatible with FSDP2. + + For FSDP2, we need to use the outer FSDP-wrapped model instead of + the parameter passed by mtq.quantize to properly handle DTensor. 
+ + Args: + model: FSDP2-wrapped model + dataloader: Calibration dataloader + accelerator: Accelerator instance for device management + + Returns: + Calibration function compatible with mtq.quantize + """ + + def calibrate(unwrapped_model): + """Calibration loop that uses the FSDP-wrapped model.""" + for batch in tqdm(dataloader, desc="Calibrating"): + if isinstance(batch, dict): + batch = { + k: v.to(accelerator.device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + # Use outer model (FSDP-wrapped), not the parameter + # Important: We should forward pass using the unwrapped model + # mtq.quantize will unwrap the model & pass to the forward_loop + model(**batch) + + return calibrate + + +def export_model( + model: nn.Module, + accelerator: Accelerator, + export_path: str | Path, + architectures: list[str], +): + """Export quantized model to HuggingFace format. + + Args: + model: Quantized model + accelerator: Accelerator instance for state dict gathering + export_path: Directory to export model to + """ + export_dir = Path(export_path) + export_dir.mkdir(parents=True, exist_ok=True) + + post_state_dict, hf_quant_config = _export_hf_checkpoint( + model, torch.bfloat16, accelerator=accelerator + ) + + if accelerator.is_main_process: + # Save hf_quant_config.json for backward compatibility + with open(f"{export_dir}/hf_quant_config.json", "w") as file: + json.dump(hf_quant_config, file, indent=4) + + hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + + # Save model + model.save_pretrained(export_dir, state_dict=post_state_dict, save_modelopt_state=False) + + original_config = f"{export_dir}/config.json" + config_data = {} + + with open(original_config) as file: + config_data = json.load(file) + + config_data["quantization_config"] = hf_quant_config + # Update config architectures to use original architectures that does not have FSDP prefix + config_data["architectures"] = architectures + + with open(original_config, "w") as file: + json.dump(config_data, file, indent=4) + + +def main(args): + """Main quantization workflow.""" + # Validate GPU availability + if not torch.cuda.is_available(): + raise OSError("GPU is required for quantization.") + + # Validate quantization format + if args.qformat not in QUANT_CFG_CHOICES: + raise ValueError( + f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}" + ) + + # Set random seeds + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + torch.manual_seed(RAND_SEED) + + # Initialize accelerator + accelerator = Accelerator() + + print(f"Rank: {os.environ.get('RANK', 'Not set')}") + print(f"World Size: {os.environ.get('WORLD_SIZE', 'Not set')}") + print(f"Local Rank: {os.environ.get('LOCAL_RANK', 'Not set')}") + + # Load tokenizer + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + default_padding_side = tokenizer.padding_side + tokenizer.padding_side = "left" # Left padding for better calibration + + # Set default dataset if not provided + if args.dataset is None: + args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] + warnings.warn( + "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." 
+ ) + # Adjust calib_size to match dataset length by extending or truncating as needed + args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ + : len(args.dataset) + ] + + # Create calibration dataloader with max batch size + calib_dataloader = create_calibration_dataloader( + tokenizer=tokenizer, + dataset_names=args.dataset, + calib_sizes=args.calib_size, + batch_size=args.batch_size, + ) + + # Load and prepare model + model, model_type, original_architectures, calib_dataloader = load_and_prepare_model( + model_path=args.pyt_ckpt_path, + calib_dataloader=calib_dataloader, + accelerator=accelerator, + trust_remote_code=args.trust_remote_code, + ) + + # Build quantization config + quant_cfg = build_quant_cfg( + args.qformat, + args.kv_cache_qformat, + args.awq_block_size, + None, + model_type, + QUANT_CFG_CHOICES, + KV_QUANT_CFG_CHOICES, + ) + + # Quantize the model + if accelerator.is_main_process: + print("Starting quantization...") + + start_time = time.time() + + if need_calibration(quant_cfg): + calibrate_fn = create_fsdp2_calibration_loop(model, calib_dataloader, accelerator) + else: + calibrate_fn = None + warnings.warn("Dynamic quantization. Calibration skipped.") + + with torch.no_grad(): + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_fn) + + elapsed = time.time() - start_time + + if accelerator.is_main_process: + print(f"Quantization completed in {elapsed:.2f}s") + mtq.print_quant_summary(model) + + start_time = time.time() + export_model(model, accelerator, args.export_path, original_architectures) + elapsed = time.time() - start_time + + if accelerator.is_main_process: + # Restore default padding and export the tokenizer as well. + if tokenizer is not None: + tokenizer.padding_side = default_padding_side + tokenizer.save_pretrained(args.export_path) + # Export the model + print(f"Export completed in {elapsed:.2f}s") + print(f"Model exported to {args.export_path}") + + print("Unpatching FSDP2 MP dtypes") + + +if __name__ == "__main__": + args = parse_args() + # This context manager can be removed once the update to FSDP2 function is reflected in torch + with patch_fsdp_mp_dtypes(): + main(args) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f966ffac6..954707e8d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -27,11 +27,12 @@ import torch import torch.nn as nn from safetensors.torch import save_file +from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import NVFP4QTensor -from modelopt.torch.quantization.utils import quantizer_attr_names +from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format from .layer_utils import ( @@ -114,7 +115,8 @@ def _output_hook(module, input, output): # update_experts_avg_prequant_scale(module) grouped_experts = get_experts_list(module, model_type) for modules in grouped_experts: - preprocess_linear_fusion(modules, resmooth_only=True) + with fsdp2_aware_weight_update(model, modules): + preprocess_linear_fusion(modules, resmooth_only=True) # Attach hook to layernorm modules that need to be fused if is_layernorm(module): @@ -161,7 +163,8 @@ def _output_hook(module, input, output): QUANTIZATION_FP8_PB_REAL, ]: # 
Fuse modules that have the same input - preprocess_linear_fusion(modules) + with fsdp2_aware_weight_update(model, modules): + preprocess_linear_fusion(modules) fused_linears[modules[0].name] = [module.name for module in modules] # Fuse layernorms @@ -171,7 +174,8 @@ def _output_hook(module, input, output): and tensor in output_to_layernorm ): # Pre quant scale of modules is already updated to avg_pre_quant_scale - fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): + fuse_prequant_layernorm(output_to_layernorm[tensor], modules) # The dummy forward may not be able to activate all the experts. # Process experts by naming rules like experts.0, experts.1, etc. @@ -192,7 +196,8 @@ def _output_hook(module, input, output): assert new_expert_name in module_names new_expert_modules.append(model.get_submodule(new_expert_name)) - preprocess_linear_fusion(new_expert_modules) + with fsdp2_aware_weight_update(model, new_expert_modules): + preprocess_linear_fusion(new_expert_modules) expert_id += 1 @@ -339,7 +344,9 @@ def _export_quantized_weight( def _export_hf_checkpoint( - model: nn.Module, dtype: torch.dtype | None = None + model: nn.Module, + dtype: torch.dtype | None = None, + **kwargs, ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -348,6 +355,7 @@ def _export_hf_checkpoint( Args: model: the torch model. dtype: the weights data type to export the unquantized layers or the default model data type if None. + accelerator: the accelerator instance in case of distributed export setup. Returns: post_state_dict: Dict containing quantized weights @@ -361,6 +369,8 @@ def _export_hf_checkpoint( f"({dtype}), which may lead to numerical errors." ) + accelerator = kwargs.get("accelerator") + # Create a model layer pool # If `model.model` exists use that, otherwise use `model` itself, e.g., Nemotron-H root = getattr(model, "model", model) @@ -458,12 +468,24 @@ def _export_hf_checkpoint( # Track if any layers are quantized to properly set exclude_modules has_quantized_layers = False + fsdp_module_to_reshard = None for name, sub_module in layer_pool.items(): + # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead + if isinstance(sub_module, FSDPModule): + # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. + # We need to reshard the previous FSDPModule to prevent potential OOM. + # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. 
+ if fsdp_module_to_reshard is not None: + fsdp_module_to_reshard.reshard() + + fsdp_module_to_reshard = sub_module + if get_quantization_format(sub_module) != QUANTIZATION_NONE: has_quantized_layers = True if is_quantlinear(sub_module): - _export_quantized_weight(sub_module, dtype) + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) elif ( "Llama4TextExperts" in type(sub_module).__name__ or "GptOssExperts" in type(sub_module).__name__ @@ -481,9 +503,14 @@ def _export_hf_checkpoint( ) # Export the quantized weights for weight_name in ["gate_up_proj", "down_proj"]: - _export_quantized_weight(sub_module, dtype, weight_name) + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype, weight_name) - quantized_state_dict = model.state_dict() + if accelerator is not None: + # Gather state_dict from all ranks + quantized_state_dict = accelerator.get_state_dict(model) + else: + quantized_state_dict = model.state_dict() quantized_state_dict = postprocess_state_dict( quantized_state_dict, kv_cache_max_bound, kv_cache_format diff --git a/modelopt/torch/quantization/qtensor/base_qtensor.py b/modelopt/torch/quantization/qtensor/base_qtensor.py index 1987428c9..d5a9a4269 100644 --- a/modelopt/torch/quantization/qtensor/base_qtensor.py +++ b/modelopt/torch/quantization/qtensor/base_qtensor.py @@ -16,14 +16,13 @@ """Base Class for Real Quantized Tensor.""" import enum -import warnings -from contextlib import contextmanager import torch -from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import DTensor +from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, patch_fsdp_mp_dtypes + class QTensorType(enum.Enum): """Enumeration for defining types of quantization.""" @@ -194,62 +193,6 @@ def custom_load_from_state_dict(self, state_dict, prefix, *args, **kwargs): module._load_from_state_dict = custom_load_from_state_dict.__get__(module, type(module)) -def get_prefixed_param_names(parent_model, target_module): - """Get parameter names for a target module prefixed with the parent model name. - - This function is used to get full parameter name from FSDPParam module_info which stores the - unprefixed parameter name. - - """ - target_ids = {id(p) for p in target_module.parameters()} - return next( - ( - name.rsplit(".", 1)[0] - for name, param in parent_model.named_parameters() - if id(param) in target_ids - ), - None, # default value if no match - ) - - -@contextmanager -def no_requires_grad(): - """Context manager to temporarily set requires_grad to False. - - This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates - a new parameter with default requires_grad and then update the requires_grad attribute as needed. This - triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True - for integer tensors. - """ - original_new = torch.nn.Parameter.__new__ - - def patched_new(cls, data=None, requires_grad=True): - return original_new(cls, data, requires_grad=False) - - torch.nn.Parameter.__new__ = patched_new - try: - yield - finally: - torch.nn.Parameter.__new__ = original_new - - -@contextmanager -def enable_fake_quant(module): - """Temporarily set the fake_quant attribute of a module to True. 
- - This is used to prevent weight compression from being triggered during an unshard() call. - """ - original_fake_quant = [] - for m in module.modules(): - if hasattr(m, "weight_quantizer"): - original_fake_quant.append(m.weight_quantizer._fake_quant) - m.weight_quantizer._fake_quant = True - yield - for m in module.modules(): - if hasattr(m, "weight_quantizer"): - m.weight_quantizer._fake_quant = original_fake_quant.pop(0) - - def pack_real_quantize_weight(module, force_quantize: bool = False): """Pack real quantized tensors to a compressed format and set proper load_state_dict function.""" # Import SequentialQuantizer here to avoid circular import @@ -275,96 +218,12 @@ def _compress_and_update_module_weight(module): return False - def _create_fsdp_param_mapping(fsdp_param_list, model): - """Builds a mapping from module name to their corresponding FSDPParam. - - Args: - fsdp_param_list (list): List of FSDPParam. - model (nn.Module): FSDP root module. - - Returns: - dict: Full parameter name → FSDP parameter. - """ - return { - get_prefixed_param_names(model, param._module_info.module): param - for param in fsdp_param_list - } - - def _compress_fsdp_module(fsdp_module): - """Applies weight compression to an FSDP-wrapped module and updates its sharded parameter group. - - This function unshards the FSDP module to access full weights and compresses each eligible submodule’s weights. - A new FSDPParam wrapped with `QFSDPParam` is registered to the FSDPParamGroup for future handling of - sharding and unsharding. The weight_scale buffers registered during compression and the FSDPModule are reharded - once compression is complete. - - Args: - fsdp_module (nn.Module): The FSDP-wrapped module to compress. - - Returns: - None - """ - # Unshard FSDPmodule by temporarily setting _fake_quant to prevent weight compression from being triggered - with enable_fake_quant(fsdp_module): - fsdp_module.unshard() - - # Get the FSDPParamGroup for the FSDPModule - fsdp_param_group = fully_shard.state(fsdp_module)._fsdp_param_group - - if getattr(fsdp_param_group, "fsdp_params", None) is None: - warnings.warn( - f"FSDPParamGroup for {fsdp_module} has no fsdp_params, skipping compression" - ) - return - - # Create FSDPParam mapping dictionary to keep track of FSDPParams to update/delete - fsdp_param_mapping = _create_fsdp_param_mapping(fsdp_param_group.fsdp_params, fsdp_module) - - for name, submodule in fsdp_module.named_modules(): - # This is to handle case where the root FSDPModule has parameters. - # We skip all the parameters that dont belong to the FSDPParamGroup. 
- if name not in fsdp_param_mapping: - continue - - if _compress_and_update_module_weight(submodule): - old_fsdp_param = fsdp_param_mapping[name] - - # Update mp policy to reflect the new dtype - new_mp_policy = MixedPrecisionPolicy( - param_dtype=submodule.weight.dtype, - reduce_dtype=None, - output_dtype=None, - cast_forward_inputs=False, - ) - with no_requires_grad(): - # Create a new QFSDPParam parameter - new_param = QFSDPParam( - submodule.weight, - old_fsdp_param._module_info, - old_fsdp_param.mesh_info, - old_fsdp_param.post_forward_mesh_info, - old_fsdp_param.device, - None, - new_mp_policy, - None, - ) - - # Update the FSDPParam mapping to keep track of the new FSDPParam - fsdp_param_mapping[name] = new_param - # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam - old_fsdp_param._post_load_hook_handle.remove() - - # Update FSDPParam list with new compressed weights - fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values()) - - # Reshard FSDP root module - fsdp_module.reshard() - - with SequentialQuantizer.convert_to_single_quantizer(module), torch.no_grad(): - for _, m in module.named_modules(): - # If FSDP module, we need to additionally process the FSDPParam list - if isinstance(m, FSDPModule): - _compress_fsdp_module(m) - else: - # Compress weights and update module weight - _compress_and_update_module_weight(m) + with ( + SequentialQuantizer.convert_to_single_quantizer(module), + torch.no_grad(), + patch_fsdp_mp_dtypes(), + ): + for name, m in module.named_modules(): + if name != "": + with fsdp2_aware_weight_update(module, m): + _compress_and_update_module_weight(m) diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 65861695f..2ff1b17e9 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -81,7 +81,9 @@ def get_weights_scaling_factor( # Get per block amax per_block_amax = reduce_block_amax(input, block_sizes={-1: block_size}).float() # Get per-block-scale - per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2) + per_block_scale = per_block_amax / ( + 6.0 * weights_scaling_factor_2.to(per_block_amax.device) + ) # Set all zero values in scale to 1.0 per_block_scale[per_block_scale == 0] = 1.0 # Convert to torch.float8_e4m3fn diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 43e269fa1..22132d598 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -15,18 +15,24 @@ """Quantization utilities.""" +from __future__ import annotations + from collections import namedtuple -from collections.abc import Generator from contextlib import ExitStack, contextmanager, nullcontext +from typing import TYPE_CHECKING import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributed.fsdp import FSDPModule +from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard +from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate from modelopt.torch.utils import get_unwrapped_name, print_rank_0 +if TYPE_CHECKING: + from collections.abc import Generator + __all__ = [ "EXPORT_MODE", "convert_quantization_axis_to_reduce_axis", @@ -357,13 +363,19 @@ def _get_fsdp2_mesh(module: nn.Module): return fsdp_state._fsdp_param_group.post_forward_mesh_info.mesh +def _get_module_name(module: nn.Module, root_model: nn.Module): + 
name_to_module = dict(root_model.named_modules()) + target_module_name = next((name for name, m in name_to_module.items() if m is module), None) + return target_module_name + + def _get_enclosing_fsdp_module(module: nn.Module, root_model: nn.Module): """Get the enclosing FSDP module for a given module.""" if isinstance(module, FSDPModule): return module name_to_module = dict(root_model.named_modules()) - target_module_name = next((name for name, m in name_to_module.items() if m is module), None) + target_module_name = _get_module_name(module, root_model) if target_module_name is None: raise ValueError(f"Module {module} not found in the root model {root_model}.") @@ -467,3 +479,227 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict): key = get_unwrapped_name(name, model) if isinstance(module, TensorQuantizer) and key in quantizer_state_dict: module.load_state_dict(quantizer_state_dict[key]) + + +@contextmanager +def patch_fsdp_mp_dtypes(): + """Patch FSDP2 to handle mixed dtypes properly during quantization. + + This patch is used to relax the requirement of uniform original parameter dtype in FSDP2 and is + copied from the latest torch FSDP repository `torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py `_. + """ + + def _init_mp_dtypes(self) -> None: + """This function is directly copied from the latest version of torch FSDP.""" + for fsdp_param in self.fsdp_params: + fsdp_param.init_dtype_attrs(self.mp_policy) + + trainable_params: list[FSDPParam] = [ + p for p in self.fsdp_params if p.sharded_param.requires_grad + ] + orig_dtypes = {p.orig_dtype for p in trainable_params} + reduce_dtypes = {p.reduce_dtype for p in trainable_params} + + if len(trainable_params) > 0 and len(orig_dtypes) != 1: + raise AssertionError( + f"FSDP expects uniform original parameter dtype but got {orig_dtypes}" + ) + + self._orig_dtype = next(iter(orig_dtypes)) if len(trainable_params) else None + + if len(trainable_params) > 0 and len(reduce_dtypes) != 1: + raise AssertionError(f"FSDP expects uniform reduce dtype but got {reduce_dtypes}") + + self._reduce_dtype = next(iter(reduce_dtypes)) if len(trainable_params) else None + + # Apply the patch + original_init_mp_dtypes = ( + torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes + ) + try: + torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = ( + _init_mp_dtypes + ) + yield + finally: + torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup._init_mp_dtypes = ( + original_init_mp_dtypes + ) + + +def get_prefixed_param_names(parent_model, target_module): + """Get parameter names for a target module prefixed with the parent model name. + + This function is used to get full parameter name from FSDPParam module_info which stores the + unprefixed parameter name. + + """ + target_ids = {id(p) for p in target_module.parameters()} + return next( + ( + name.rsplit(".", 1)[0] + for name, param in parent_model.named_parameters() + if id(param) in target_ids + ), + None, # default value if no match + ) + + +def create_fsdp_param_mapping(fsdp_param_list, model): + """Builds a mapping from module name to their corresponding FSDPParam. + + Args: + fsdp_param_list (list): List of FSDPParam. + model (nn.Module): FSDP root module. + + Returns: + dict: Full parameter name → FSDP parameter. 
+ """ + return { + get_prefixed_param_names(model, param._module_info.module): param + for param in fsdp_param_list + } + + +@contextmanager +def no_requires_grad(): + """Context manager to temporarily set requires_grad to False. + + This is used to allow us to call init_sharded_parameter() on the compressed weights. Currently FSDP2 creates + a new parameter with default requires_grad and then update the requires_grad attribute as needed. This + triggers an error when torch.nn.Parameter is called on compressed weights as requires_grad cannot be set to True + for integer tensors. + """ + original_new = torch.nn.Parameter.__new__ + + def patched_new(cls, data=None, requires_grad=True): + return original_new(cls, data, requires_grad=False) + + torch.nn.Parameter.__new__ = patched_new + try: + yield + finally: + torch.nn.Parameter.__new__ = original_new + + +@contextmanager +def enable_fake_quant(module): + """Temporarily set the fake_quant attribute of a module to True. + + This is used to prevent weight compression from being triggered during an unshard() call. + """ + original_fake_quant = [] + for m in module.modules(): + if hasattr(m, "weight_quantizer"): + original_fake_quant.append(m.weight_quantizer._fake_quant) + m.weight_quantizer._fake_quant = True + yield + for m in module.modules(): + if hasattr(m, "weight_quantizer"): + m.weight_quantizer._fake_quant = original_fake_quant.pop(0) + + +@contextmanager +def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): + """Context manager to update the FSDPParam list if an update is made to a submodule of an FSDPModule. + + This context manager is to be used when updating a weight of a sharded module to ensure the changes are properly + reflected for future unsharding and resharding the FSDP root module. The context manager will unshard the FSDP root + module, register new FSDPParam/QFSDPParam for the updated modules and updates the FSDP param group list. + + If reshard is True, the context manager will also reshard the FSDP root module after the weight update. + + Args: + root_model (nn.Module): The root model of the FSDPModule. + modules_to_update (list): The list of modules to update which should be a list of modules that are + direct children of the FSDPModule. + reshard (bool): Whether to reshard the FSDP root module after the weight update. 
+ + Returns: + None + """ + try: + if isinstance(root_model, FSDPModule): + # Get FSDP root module, if none is returned, then the update is not made to a submodule of an FSDPModule + if not isinstance(modules_to_update, list): + modules_to_update = [modules_to_update] + + root_modules = set() + for module in modules_to_update: + root_module = _get_enclosing_fsdp_module(module, root_model) + root_modules.add(root_module) + + # Ensure all modules in root_modules are the same + assert len(root_modules) == 1, "All modules must be in the same root FSDPModule" + root_module = next(iter(root_modules)) + + # Check if root module state is sharded and unshard if needed + if fully_shard.state(root_module)._fsdp_param_group.is_sharded: + with enable_fake_quant(root_module): + root_module.unshard() + + # Get FSDPParam list + fsdp_param_group = fully_shard.state(root_module)._fsdp_param_group + fsdp_param_mapping = create_fsdp_param_mapping(fsdp_param_group.fsdp_params, root_model) + + # Assert that all the modules in the module list are present in this fsdp_param_group + if len(modules_to_update) > 1: + for module in modules_to_update: + name = _get_module_name(module, root_model) + assert name in fsdp_param_mapping, ( + f"Module {module} not found in fsdp_param_mapping" + ) + # Yields for necessary weight updates/processing + yield + finally: + from modelopt.torch.quantization.qtensor.base_qtensor import QFSDPParam, QTensorWrapper + + if isinstance(root_model, FSDPModule): + # Update FSDPParam list + for module in modules_to_update: + name = _get_module_name(module, root_model) + if name not in fsdp_param_mapping: + continue + + old_fsdp_param = fsdp_param_mapping[name] + + # Update mp policy to reflect the new dtype + new_mp_policy = MixedPrecisionPolicy( + param_dtype=module.weight.dtype, + reduce_dtype=None, + output_dtype=None, + cast_forward_inputs=False, + ) + + with no_requires_grad(): + # Create a new QFSDPParam or FSDPParam based on weight type + param_class = ( + QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam + ) + + new_param = param_class( + module.weight, + old_fsdp_param._module_info, + old_fsdp_param.mesh_info, + old_fsdp_param.post_forward_mesh_info, + old_fsdp_param.device, + None, + new_mp_policy, + None, + ) + if not isinstance(new_param, QFSDPParam): + new_param.init_dtype_attrs(new_mp_policy) + + # Update the FSDPParam mapping to keep track of the new FSDPParam + fsdp_param_mapping[name] = new_param + + # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam + old_fsdp_param._post_load_hook_handle.remove() + + # Update FSDPParam list with new compressed weights + fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values()) + + # Reshard FSDP root module + if reshard: + with enable_fake_quant(root_module): + root_module.reshard() diff --git a/tests/_test_utils/torch_export/export_utils.py b/tests/_test_utils/torch_export/export_utils.py index e5cd6b8a8..8d2d88608 100644 --- a/tests/_test_utils/torch_export/export_utils.py +++ b/tests/_test_utils/torch_export/export_utils.py @@ -18,20 +18,22 @@ # Models class ToyModel(torch.nn.Module): - def __init__(self, dims=[10, 10, 10, 10]): + def __init__(self, dims=[10, 10, 10, 10], bias=True): super().__init__() assert len(dims) >= 2 if len(dims) == 2: - self.linears = torch.nn.Linear(dims[0], dims[1]) + self.linears = torch.nn.Linear(dims[0], dims[1], bias=bias) else: - linears = [torch.nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)] + linears = [ + torch.nn.Linear(dims[i], 
dims[i + 1], bias=bias) for i in range(len(dims) - 1) + ] self.linears = torch.nn.Sequential(*linears) def forward(self, x): return self.linears(x) -class SmallQKVModel(torch.nn.Module): +class SmallLinearModelwithCustomWeight(torch.nn.Module): def __init__(self, weights): super().__init__() self.q_proj = torch.nn.Linear(weights[0].shape[1], weights[0].shape[0], bias=False) @@ -52,6 +54,35 @@ def forward(self, x): return x +class SmallQKVModel(torch.nn.Module): + def __init__(self, dim=4, device="cuda", apply_embed=False): + super().__init__() + self.embedding = torch.nn.Embedding(2, dim) + self.q_proj = torch.nn.Linear(dim, dim, bias=False) + self.k_proj = torch.nn.Linear(dim, dim, bias=False) + self.v_proj = torch.nn.Linear(dim, dim, bias=False) + self.o_proj = torch.nn.Linear(dim, dim, bias=False) + self.device = device + self.config = None + self.apply_embed = apply_embed + # TODO: Debug why fsdp2 modifies bias of layernorm for awq + self.input_layernorm = torch.nn.LayerNorm(dim, bias=False) + + def forward(self, x): + if self.apply_embed: + x = self.embedding(x) + + x = self.input_layernorm(x) + q_proj = self.q_proj(x) + k_proj = self.k_proj(x) + v_proj = self.v_proj(x) + scores = torch.matmul(q_proj, k_proj.transpose(-2, -1)) + attn = torch.nn.functional.softmax(scores, dim=-1) + x = torch.matmul(attn, v_proj) + o_proj = self.o_proj(x) + return o_proj + + # Quantization configs partial_fp8_config = { "quant_cfg": { diff --git a/tests/gpu/torch/export/test_export.py b/tests/gpu/torch/export/test_export.py index 36d155155..7c840ff0d 100644 --- a/tests/gpu/torch/export/test_export.py +++ b/tests/gpu/torch/export/test_export.py @@ -16,7 +16,7 @@ import pytest import torch from _test_utils.torch_export.export_utils import ( - SmallQKVModel, + SmallLinearModelwithCustomWeight, ToyModel, only_input_quantizer_fp8_config, only_output_quantizer_fp8_config, @@ -306,7 +306,7 @@ def test_adjust_attn_amax_values( q_weight, k_weight, v_weight, o_weight, expected_qkv_amax, expected_o_amax, config ): # Initialize model and quantize to insert quantizers - model = SmallQKVModel([q_weight, k_weight, v_weight, o_weight]).to("cuda") + model = SmallLinearModelwithCustomWeight([q_weight, k_weight, v_weight, o_weight]).to("cuda") mtq.quantize(model, config, lambda x: x(torch.randn(1, 4, q_weight.shape[1], device="cuda"))) adjust_attn_amax_values(model) # Weight quantizer amax must remain unchanged for non qkv layers @@ -375,11 +375,12 @@ def test_get_scaling_factor( q_weight, k_weight, v_weight, o_weight, config, expected_amax, maxbound ): # Initialize model and quantize to insert quantizers - model = SmallQKVModel([q_weight, k_weight, v_weight, o_weight]).to("cuda") + model = SmallLinearModelwithCustomWeight([q_weight, k_weight, v_weight, o_weight]).to("cuda") mtq.quantize(model, config, lambda x: x(torch.ones(1, 2, q_weight.shape[1], device="cuda"))) for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and module.is_enabled: scale = get_scaling_factor(module) + print(f"DEBUG LOG: Scale: {scale}, Expected: {expected_amax[0] / maxbound}") assert torch.allclose( scale, torch.tensor((expected_amax[0] / maxbound), dtype=scale.dtype), diff --git a/tests/gpu/torch/export/test_fsdp2_export.py b/tests/gpu/torch/export/test_fsdp2_export.py new file mode 100644 index 000000000..18bcf436b --- /dev/null +++ b/tests/gpu/torch/export/test_fsdp2_export.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import copy +from functools import partial + +import pytest +import torch +from _test_utils.torch_dist.dist_utils import get_device_counts, spawn_multiprocess_job +from _test_utils.torch_export.export_utils import SmallQKVModel, ToyModel +from torch.distributed._composable.fsdp import fully_shard + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.layer_utils import is_quantlinear +from modelopt.torch.export.unified_export_hf import ( + _export_quantized_weight, + requantize_resmooth_fused_llm_layers, +) +from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, patch_fsdp_mp_dtypes + + +def _update_weight_test(rank, size): + """Test fsdp2 weight update context for weight update -> only value changed""" + from torch.distributed._composable.fsdp import fully_shard + + with patch_fsdp_mp_dtypes(): + # Define and shard model + model = ToyModel(dims=[4, 4], bias=False).to("cuda") + + assert not torch.equal( + model.linears.weight.data, + torch.zeros(4, 4).to(model.linears.weight.device).to(model.linears.weight.dtype), + ) + + fully_shard(model.linears) + fully_shard(model) + + torch.distributed.barrier() + + for name, module in model.named_modules(): + if "linears" in name: + with fsdp2_aware_weight_update(model, module): + module.weight.data = torch.zeros_like(module.weight.data) + + torch.distributed.barrier() + model.linears.unshard() + + # Check if weights are as expected after unshard + for param in model.parameters(): + assert torch.allclose( + torch.zeros(4, 4).to(param.data.device).to(param.data.dtype), param.data + ) + + # Check if forward pass is as expected + model.linears.reshard() + output = model(torch.randn(4, 4).to(model.linears.weight.device)) + assert torch.allclose(torch.zeros(4, 4).to(output.device).to(output.dtype), output) + + +def _compress_weight_test(rank, size): + """Test fsdp2 weight update context for weight compression -> only value,shape and dtype changed""" + from torch.distributed._composable.fsdp import fully_shard + + with patch_fsdp_mp_dtypes(): + # Define and shard model + model = ToyModel(dims=[6, 6], bias=False).to("cuda") + + assert not torch.equal( + model.linears.weight.data, + torch.zeros(6, 6).to(model.linears.weight.device).to(model.linears.weight.dtype), + ) + + fully_shard(model.linears) + fully_shard(model) + torch.distributed.barrier() + + for name, module in model.named_modules(): + if "linears" in name: + with fsdp2_aware_weight_update(model, module): + module.weight.data = ( + torch.zeros(2, 2).to(torch.float8_e4m3fn).to(module.weight.data.device) + ) + + torch.distributed.barrier() + model.linears.unshard() + # Check if weights are as expected after unshard + for param in model.parameters(): + assert param.data.dtype == torch.float8_e4m3fn + + +def _compare_parameters_and_buffers(model1, model2): + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + 
assert len(params1) == len(params2) + for name, param in params1.items(): + assert torch.allclose(param.to(torch.bfloat16), params2[name].to(torch.bfloat16)), ( + f"Parameters {name} are not close, {param} != {params2[name]}" + ) + buffers1 = dict(model1.named_buffers()) + buffers2 = dict(model2.named_buffers()) + assert len(buffers1) == len(buffers2) + for name, buffer in buffers1.items(): + assert torch.allclose(buffer.to(torch.bfloat16), buffers2[name].to(torch.bfloat16)), ( + f"Buffers {name} are not close, {buffer} != {buffers2[name]}" + ) + + +def _fuse_layers(rank, size, quant_config): + with patch_fsdp_mp_dtypes(): + # Initialize model + model = SmallQKVModel(dim=32).to("cuda") + non_fsdp_model = SmallQKVModel(dim=32).to("cuda") + non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict())) + model.eval() + non_fsdp_model.eval() + + _compare_parameters_and_buffers(model, non_fsdp_model) + + # Create calibration data ONCE + calib_data = torch.randn(1, 32, device="cuda") + + def calib_fn(x): + return x(calib_data) + + # Shard model + fully_shard(model) + torch.distributed.barrier() + + # Quantize model + mtq.quantize(model, quant_config, calib_fn) + mtq.quantize(non_fsdp_model, quant_config, calib_fn) + + torch.distributed.barrier() + + model.apply_embed = True + non_fsdp_model.apply_embed = True + + requantize_resmooth_fused_llm_layers(model) + requantize_resmooth_fused_llm_layers(non_fsdp_model) + + torch.distributed.barrier() + + # Unshard model + model.unshard() + + _compare_parameters_and_buffers(model, non_fsdp_model) + + +def _export_quantized_weight_test(rank, size, quant_config): + import copy + + from torch.distributed._composable.fsdp import fully_shard + + with patch_fsdp_mp_dtypes(): + # Initialize model + model = SmallQKVModel(dim=32).to("cuda") + non_fsdp_model = SmallQKVModel(dim=32).to("cuda") + non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict())) + model.eval() + non_fsdp_model.eval() + _compare_parameters_and_buffers(model, non_fsdp_model) + + # Create calibration data ONCE + calib_data = torch.randn(1, 32, device="cuda") + + def calib_fn(x): + return x(calib_data) + + # Shard model + fully_shard(model) + torch.distributed.barrier() + + # Quantize model + mtq.quantize(model, quant_config, calib_fn) + mtq.quantize(non_fsdp_model, quant_config, calib_fn) + + torch.distributed.barrier() + + model.apply_embed = True + non_fsdp_model.apply_embed = True + + requantize_resmooth_fused_llm_layers(model) + requantize_resmooth_fused_llm_layers(non_fsdp_model) + + torch.distributed.barrier() + + for name, sub_module in model.named_modules(): + if is_quantlinear(sub_module): + with fsdp2_aware_weight_update(model, sub_module): + _export_quantized_weight(sub_module, torch.float16) + + for name, sub_module in non_fsdp_model.named_modules(): + if is_quantlinear(sub_module): + with fsdp2_aware_weight_update(non_fsdp_model, sub_module): + _export_quantized_weight(sub_module, torch.float16) + + torch.distributed.barrier() + # Unshard model + model.unshard() + + _compare_parameters_and_buffers(model, non_fsdp_model) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_fsdp2_weight_compress_context_for_export(device_count): + spawn_multiprocess_job( + size=device_count, + job=_compress_weight_test, + backend="nccl", + ) + + +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_fsdp2_weight_update_context_for_export(device_count): + spawn_multiprocess_job( + size=device_count, + job=_update_weight_test, + backend="nccl", + ) + 
+ +@pytest.mark.parametrize( + "quant_config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.INT4_AWQ_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.NVFP4_DEFAULT_CFG, + mtq.NVFP4_AWQ_LITE_CFG, + # mtq.W4A8_AWQ_BETA_CFG, #TODO: Fix unit test for this case + # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, #TODO: Fix unit test for this case + mtq.W4A8_MXFP4_FP8_CFG, + mtq.NVFP4_MLP_ONLY_CFG, + ], +) +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config): + spawn_multiprocess_job( + size=device_count, + job=partial(_fuse_layers, quant_config=quant_config), + backend="nccl", + ) + + +@pytest.mark.parametrize( + "quant_config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.INT4_AWQ_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.NVFP4_DEFAULT_CFG, + mtq.NVFP4_AWQ_LITE_CFG, + # mtq.W4A8_AWQ_BETA_CFG, #TODO: Fix unit test for this case + # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, #TODO: Fix unit test for this case + mtq.W4A8_MXFP4_FP8_CFG, + mtq.NVFP4_MLP_ONLY_CFG, + ], +) +@pytest.mark.parametrize("device_count", get_device_counts()) +def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config): + spawn_multiprocess_job( + size=device_count, + job=partial(_export_quantized_weight_test, quant_config=quant_config), + backend="nccl", + )
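A minimal sketch of how the new `fsdp2_aware_weight_update` context manager is meant to be used outside the test harness, mirroring `_update_weight_test` above. It assumes `torch.distributed` is already initialized (for example via a multi-GPU `torchrun` launch) and that the `ToyModel` test helper is importable as in the test suite; the wrapper function name is arbitrary and the snippet is illustrative only, not part of this patch.

```python
# Illustrative sketch (not part of this diff): in-place weight edits on an
# FSDP2-sharded module should happen inside fsdp2_aware_weight_update so the
# FSDPParam group is rebuilt and the module is resharded afterwards.
import torch
from torch.distributed.fsdp import fully_shard

from _test_utils.torch_export.export_utils import ToyModel
from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, patch_fsdp_mp_dtypes


def zero_out_sharded_linear():
    # Assumes torch.distributed is initialized, e.g. via `torchrun --nproc_per_node=2 ...`.
    with patch_fsdp_mp_dtypes():
        model = ToyModel(dims=[4, 4], bias=False).to("cuda")
        fully_shard(model.linears)
        fully_shard(model)

        # The context manager unshards the enclosing FSDP module, lets the update run on
        # the full weight, registers a fresh FSDPParam, and reshards on exit (reshard=True).
        with fsdp2_aware_weight_update(model, model.linears):
            model.linears.weight.data = torch.zeros_like(model.linears.weight.data)

        # Subsequent forward passes see the updated (here: zeroed) weight.
        model(torch.randn(4, 4, device="cuda"))
```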