diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index a89e99355..4ac8b7f02 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -20,12 +20,20 @@ import sys import warnings from pathlib import Path +from typing import Any import torch import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import get_max_memory -from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizerBase, + ProcessorMixin, +) try: from huggingface_hub import snapshot_download @@ -33,7 +41,7 @@ snapshot_download = None import modelopt.torch.quantization as mtq -from modelopt.torch.utils.image_processor import MllamaImageProcessor +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] @@ -127,52 +135,50 @@ def build_quant_cfg( qformat, kv_cache_qformat, awq_block_size, - auto_quantize, model_type, quant_cfg_choices, kv_quant_cfg_choices, -): +) -> dict[str, Any]: quant_cfg = {} - if not auto_quantize: - assert qformat in quant_cfg_choices, ( - f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" - ) + assert qformat in quant_cfg_choices, ( + f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" + ) - quant_cfg = quant_cfg_choices[qformat] - - if "awq" in qformat: - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - # If awq_block_size argument is provided, update weight_quantizer - if awq_block_size: - weight_quantizer["block_sizes"][-1] = awq_block_size - - # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models - if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: - quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} - - enable_quant_kv_cache = kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - - # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. - if enable_quant_kv_cache: - quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( - quant_cfg, - getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], - ) + quant_cfg = quant_cfg_choices[qformat] + + if "awq" in qformat: + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) + weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + if isinstance(weight_quantizer, list): + weight_quantizer = weight_quantizer[0] + # If awq_block_size argument is provided, update weight_quantizer + if awq_block_size: + weight_quantizer["block_sizes"][-1] = awq_block_size + + # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models + if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: + quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} + + enable_quant_kv_cache = kv_cache_qformat != "none" + print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") + + # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. 
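A minimal call-site sketch for the slimmed-down build_quant_cfg, mirroring the updated invocation in multinode_ptq.py further below; QUANT_CFG_CHOICES and KV_QUANT_CFG_CHOICES stand for the dictionaries already defined in the example scripts, and the literal argument values are placeholders, not part of the patch.

# Sketch only -- assumes build_quant_cfg is imported from examples/llm_ptq/example_utils.py.
quant_cfg = build_quant_cfg(
    "nvfp4",               # qformat: exactly one format on this (non-auto-quantize) path
    "fp8",                 # kv_cache_qformat, or "none" to skip KV-cache quantization
    0,                     # awq_block_size, only consulted for AWQ formats
    "llama",               # model_type
    QUANT_CFG_CHOICES,     # maps qformat -> quantization config dict
    KV_QUANT_CFG_CHOICES,  # maps kv qformat -> name of the mtq KV-cache config
)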
+ if enable_quant_kv_cache: + quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( + quant_cfg, + getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"], + ) - # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. - if model_type == "gemma" and "int8_sq" in qformat: - quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} + # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. + if model_type == "gemma" and "int8_sq" in qformat: + quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} - if model_type == "phi4mm": - # Only quantize the language model - quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} - quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + if model_type == "phi4mm": + # Only quantize the language model + quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} + quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} if model_type == "qwen3next" and qformat == "nvfp4": # Disable the attention projection layers to retain accuracy @@ -191,7 +197,7 @@ def is_speculative(hf_config): ) -def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs): +def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase: print(f"Initializing tokenizer from {ckpt_path}") if "vila" in ckpt_path.lower(): @@ -212,8 +218,12 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs): def get_processor( - ckpt_path, model_type, device=None, trust_remote_code=False, attn_implementation=None -): + ckpt_path, + model_type, + device: torch.device = "auto", + trust_remote_code=False, + attn_implementation=None, +) -> BaseImageProcessor | ProcessorMixin | None: """ Returns a :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` object. 
""" @@ -248,6 +258,8 @@ def get_processor( return MllamaImageProcessor(processor, device) + return None + def get_dtype(dtype): if dtype == "bf16": diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 57f0b5a89..a9862a742 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -32,12 +32,15 @@ is_nemotron_vl, run_nemotron_vl_preview, ) +from torch.utils.data import DataLoader from transformers import ( AutoConfig, AutoModelForCausalLM, AutoProcessor, PreTrainedTokenizer, + PreTrainedTokenizerBase, PreTrainedTokenizerFast, + ProcessorMixin, WhisperProcessor, ) @@ -59,7 +62,7 @@ get_max_batch_size, get_supported_datasets, ) -from modelopt.torch.utils.image_processor import MllamaImageProcessor +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader @@ -94,22 +97,82 @@ mto.enable_huggingface_checkpointing() +def make_calib_dataloader( + args: argparse.Namespace, + language_model: torch.nn.Module, + processor: BaseImageProcessor | ProcessorMixin | None, + tokenizer: PreTrainedTokenizerBase | None, + device: torch.device, + model_type: str | None, +) -> tuple[DataLoader, str | None]: + calib_dataloader = None + first_text_speech_dataset = None + if model_type == "mllama": + assert processor is not None and isinstance(processor, MllamaImageProcessor), ( + "The MllamaImageProcessor must be set." + ) + assert len(args.calib_size) == 1, ( + "mllama only supports one dataset for calibration, can extend this in the future" + ) + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=args.dataset[0] if args.dataset else "scienceqa", + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + ) + elif model_type == "whisper": + assert processor is not None and isinstance(processor, WhisperProcessor), ( + "The AutoProcessor must be set." 
+ ) + assert len(args.calib_size) == 1, ( + "whisper only supports one dataset for calibration, can extend this in the future" + ) + calib_dataloader, first_text_speech_dataset = get_speech_dataset_dataloader( + dataset_name=args.dataset[0] if args.dataset else "peoples_speech", + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + device=device, + dtype=language_model.dtype, + ) + else: + assert tokenizer is not None and isinstance( + tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) + ), "The PreTrainedTokenizer must be set" + # Labels are only needed for gradient-based auto_quantize + include_labels = ( + args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" + ) + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=include_labels, + ) + return calib_dataloader, first_text_speech_dataset + + def auto_quantize( - model, - qformat, - calib_dataloader, - calibrate_loop, - auto_quantize_bits, - batch_size=1, + args: argparse.Namespace, + language_model: torch.nn.Module, + calib_dataloader: DataLoader, auto_quantize_method="gradient", auto_quantize_score_size=128, auto_quantize_checkpoint=None, ): - qformat_list = qformat.split(",") + """Auto search quantization of multiple formats.""" + + assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( + "Auto Quantization is not supported for pipeline parallel size > 1" + ) + + qformat_list = args.qformat.split(",") assert qformat_list, "No quantization formats provided" # Check if all provided quantization formats are supported assert all( - qformat + args.qformat in [ "fp8", "int8_sq", @@ -122,7 +185,7 @@ def auto_quantize( "w4a8_mxfp4_fp8", "nvfp4_mlp_only", ] - for qformat in qformat_list + for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" def loss_func(output, data): @@ -143,9 +206,9 @@ def forward_step(model, batch): f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" ) - model, _ = mtq.auto_quantize( - model, - constraints={"effective_bits": auto_quantize_bits}, + language_model, _ = mtq.auto_quantize( + language_model, + constraints={"effective_bits": args.auto_quantize_bits}, data_loader=calib_dataloader, forward_step=forward_step, loss_func=loss_func, # Only used for gradient-based method @@ -153,7 +216,9 @@ def forward_step(model, batch): quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], num_calib_steps=len(calib_dataloader), # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. - num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)), + num_score_steps=min( + len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1) + ), verbose=True, # Disable all default disabled layers such as lm_head, mlp.gate, router etc. 
disabled_layers=list(_default_disabled_quantizer_cfg.keys()), @@ -161,6 +226,7 @@ def forward_step(model, batch): checkpoint=auto_quantize_checkpoint, ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) # We need to explicitly calibrate for kv cache quantization enable_quant_kv_cache = args.kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") @@ -169,110 +235,22 @@ def forward_step(model, batch): kv_cache_quant_cfg.pop("default") # keep other quantizers from auto_quantize mtq.set_quantizer_by_cfg( - model, + language_model, quant_cfg=kv_cache_quant_cfg, ) # Lets calibrate only the quantizers for kv cache quantization this time. Let's disable all others. with mtq.set_quantizer_by_cfg_context( - model, {"*": {"enable": False}, **kv_cache_quant_cfg} + language_model, {"*": {"enable": False}, **kv_cache_quant_cfg} ): - mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop) - return model - - -def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_only=False): - # The calibration loop for the model can be setup using the modelopt API. - # - # Example usage: - # from modelopt.torch.utils.dataset_utils import create_forward_loop - # model = ... # Initialize the model - # tokenizer = ... # Initialize the tokenizer - # quant_cfg = ... # Setup quantization configuration - # forward_loop = create_forward_loop(model=model, dataset_name="cnn_dailymail", tokenizer=tokenizer) - # mtq.quantize(model, quant_cfg, forward_loop=forward_loop) - - # The calibrate_loop is a custom defined method to run the model with the input data. - # The basic version looks like: - # - # def calibrate_loop(model, dataloader): - # for data in dataloader: - # model(**data) - # - # We also provided a util method to generate the forward_loop with additional error handlings. - - use_calibration = args.auto_quantize_bits or need_calibration(quant_cfg) - - if not use_calibration: - warnings.warn("Dynamic quantization. Calibration skipped.") - calibrate_loop = create_forward_loop(dataloader=calib_dataloader) if use_calibration else None - - assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( - "Auto Quantization is not supported for pipeline parallel size > 1" - ) + mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) + return language_model - print("Starting quantization...") - start_time = time.time() - if args.auto_quantize_bits: - model = auto_quantize( - model, - args.qformat, - calib_dataloader, - calibrate_loop, - args.auto_quantize_bits, - args.batch_size, - args.auto_quantize_method, - args.auto_quantize_score_size, - args.auto_quantize_checkpoint, - ) - elif calibration_only: - model = mtq.calibrate(model, quant_cfg["algorithm"], forward_loop=calibrate_loop) - else: - model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - end_time = time.time() - print(f"Quantization done. Total time used: {end_time - start_time}s") - return model - - -def main(args): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - # launch a memory monitor to read the currently used GPU memory. - launch_memory_monitor() - - # Force eager execution for all model types. 
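Because auto_quantize only searches the weight/activation formats, the KV-cache quantizers still need the separate max-calibration pass added above. A minimal sketch of that pattern, assuming kv_cache_quant_cfg is the pattern-keyed quantizer config derived from KV_QUANT_CFG_CHOICES and calib_dataloader is the loader built by make_calib_dataloader:

# Sketch only, not part of the patch.
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import create_forward_loop

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)

# Point the matching quantizers at the KV-cache config, then calibrate only those
# quantizers with the "max" algorithm while everything else stays disabled; the
# context manager restores the quantizer state chosen by auto_quantize on exit.
mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
with mtq.set_quantizer_by_cfg_context(
    language_model, {"*": {"enable": False}, **kv_cache_quant_cfg}
):
    mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)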
- torch.compiler.set_stance("force_eager") - - # Check that only one quantization format is provided for non auto_quant case - if not args.auto_quantize_bits: - assert len(args.qformat.split(",")) == 1, ( - "Quantization supports only one quantization format." - ) - - if not args.auto_quantize_bits: - assert ( - args.qformat - in [ - "int8_wo", - "int4_awq", - "fp8", - "nvfp4", - "nvfp4_awq", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - ] - or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES - ), f"Quantization format {args.qformat} not supported for HF export path" +def load_model(args: argparse.Namespace): # If low memory mode is enabled, we compress the model while loading the HF checkpoint. calibration_only = False if not args.low_memory_mode: - model = get_model( + full_model = get_model( args.pyt_ckpt_path, args.device, gpu_mem_percentage=args.gpu_max_mem_percentage, @@ -287,7 +265,8 @@ def main(args): quant_cfg = QUANT_CFG_CHOICES[args.qformat] if args.kv_cache_qformat != "none": quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] + quant_cfg, + getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], ) # Do not use real quant GEMM so the calibration can be more accurate. @@ -297,25 +276,21 @@ def main(args): model_kwargs = {"trust_remote_code": args.trust_remote_code} if args.attn_implementation is not None: model_kwargs["attn_implementation"] = args.attn_implementation - model = AutoModelForCausalLM.from_pretrained( + full_model = AutoModelForCausalLM.from_pretrained( args.pyt_ckpt_path, **model_kwargs, ) calibration_only = True - model_is_already_quantized = is_quantized(model) - model_type = get_model_type(model) + model_type = get_model_type(full_model) - device = model.device - if hasattr(model, "model"): - device = model.model.device + device = full_model.device + if hasattr(full_model, "model"): + device = full_model.model.device processor = None tokenizer = None - - full_model = model - - # Detect if this is a Nemotron VL model using architecture-based detection - is_nemotron_vl_model = is_nemotron_vl(full_model) + language_model = full_model + default_padding_side = None if model_type == "mllama": processor = get_processor( @@ -327,7 +302,10 @@ def main(args): ) elif model_type == "whisper": processor = get_processor( - args.pyt_ckpt_path, model_type, device, trust_remote_code=args.trust_remote_code + args.pyt_ckpt_path, + model_type, + device, + trust_remote_code=args.trust_remote_code, ) else: if args.dataset is None: @@ -364,258 +342,130 @@ def main(args): mtq.quantize(module, disabled_quant_cfg, forward_loop=None) memo.add(module) - model = language_model - model_type = get_model_type(model) + model_type = get_model_type(language_model) if model_type == "phi4mm": warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.") - if args.sparsity_fmt != "dense": - if args.batch_size == 0: - # Sparse algorithm takes more GPU memory so we reduce the batch_size by 4. 
- args.batch_size = max(get_max_batch_size(model) // 4, 1) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + return ( + full_model, + language_model, + model_type, + calibration_only, + processor, + tokenizer, + default_padding_side, + device, + ) - print(f"Use calib batch_size {args.batch_size}") - # Different calibration datasets are also available, e.g., "pile" and "wikipedia" - # Please also check the docstring for the datasets available - assert tokenizer is not None and isinstance( - tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) - ), "The PreTrainedTokenizer must be set" - calib_dataloader = get_dataset_dataloader( - dataset_name=args.dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_samples=args.calib_size, - max_sample_length=args.calib_seq, - device=device, - ) - model = mts.sparsify( - model, - args.sparsity_fmt, - config={"data_loader": calib_dataloader, "collect_func": lambda x: x}, +def sparsity_main( + args: argparse.Namespace, + full_model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase | None, + device: torch.device, +): + if args.batch_size == 0: + # Sparse algorithm takes more GPU memory so we reduce the batch_size by 4. + args.batch_size = max(get_max_batch_size(full_model) // 4, 1) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + + # Different calibration datasets are also available, e.g., "pile" and "wikipedia" + # Please also check the docstring for the datasets available + assert tokenizer is not None and isinstance( + tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) + ), "The PreTrainedTokenizer must be set" + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + max_sample_length=args.calib_seq, + device=device, + ) + full_model = mts.sparsify( + full_model, + args.sparsity_fmt, + config={"data_loader": calib_dataloader, "collect_func": lambda x: x}, + ) + mts.export(full_model) + + +def mono_quantize( + args: argparse.Namespace, + quant_cfg: dict[str, Any], + full_model: torch.nn.Module, + language_model: torch.nn.Module, + model_type: str | None, + calibration_only: bool, + calib_dataloader: DataLoader, + is_nemotron_vl_model: bool, +): + """Plain quantization of the given language model to a single quantization configuration.""" + + model_is_already_quantized = is_quantized(language_model) + + if "awq" in args.qformat: + print( + "\n####\nAWQ calibration could take longer than other calibration methods. " + "Consider reducing calib_size to reduce calibration time.\n####\n" ) - mts.export(model) - if args.auto_quantize_bits or args.qformat in QUANT_CFG_CHOICES: - if "awq" in args.qformat: - print( - "\n####\nAWQ calibration could take longer than other calibration methods. 
" - "Consider reducing calib_size to reduce calibration time.\n####\n" - ) + # For Nemotron VL models, disable quantization of vision components + if is_nemotron_vl_model: + print("Disabling quantization for vision components in Nemotron VL model") + quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + # Also disable radio model components specifically + quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} + quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} - if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. - sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=model.dtype).to( - model.device - ) - * 100 - ) - else: - sample_input_single_batch = None + if not model_is_already_quantized or calibration_only: + if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": + print("Applying nvfp4 quantization (MoE only) for gpt-oss") - run_auto_quant = args.auto_quantize_bits is not None + # quantize the model - args.batch_size = get_max_batch_size( - model, - max_sample_length=args.calib_seq, - sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, - sample_input_single_batch=sample_input_single_batch, - enable_grad=run_auto_quant, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + use_calibration = need_calibration(quant_cfg) - print(f"Use calib batch_size {args.batch_size}") + if not use_calibration: + warnings.warn("Dynamic quantization. Calibration skipped.") + calibrate_loop = ( + create_forward_loop(dataloader=calib_dataloader) if use_calibration else None + ) - calib_dataloader = None - if model_type == "mllama": - assert processor is not None and isinstance(processor, MllamaImageProcessor), ( - "The MllamaImageProcessor must be set." - ) - assert len(args.calib_size) == 1, ( - "mllama only supports one dataset for calibration, can extend this in the future" - ) - calib_dataloader = get_vlm_dataset_dataloader( - dataset_name=args.dataset[0] if args.dataset else "scienceqa", - processor=processor, - batch_size=args.batch_size, - num_samples=args.calib_size[0], - ) - elif model_type == "whisper": - assert processor is not None and isinstance(processor, WhisperProcessor), ( - "The AutoProcessor must be set." 
- ) - assert len(args.calib_size) == 1, ( - "whisper only supports one dataset for calibration, can extend this in the future" - ) - calib_dataloader, first_text = get_speech_dataset_dataloader( - dataset_name=args.dataset[0] if args.dataset else "peoples_speech", - processor=processor, - batch_size=args.batch_size, - num_samples=args.calib_size[0], - device=device, - dtype=model.dtype, + if calibration_only: + language_model = mtq.calibrate( + language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop ) else: - assert tokenizer is not None and isinstance( - tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) - ), "The PreTrainedTokenizer must be set" - # Labels are only needed for gradient-based auto_quantize - include_labels = ( - args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" - ) - calib_dataloader = get_dataset_dataloader( - dataset_name=args.dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_samples=args.calib_size, - device=device, - include_labels=include_labels, - ) - - quant_cfg = build_quant_cfg( - args.qformat, - args.kv_cache_qformat, - args.awq_block_size, - args.auto_quantize_bits, - model_type, - QUANT_CFG_CHOICES, - KV_QUANT_CFG_CHOICES, - ) + language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop) - # For Nemotron VL models, disable quantization of vision components + # For VL models, update full_model to use the quantized language model if is_nemotron_vl_model: - print("Disabling quantization for vision components in Nemotron VL model") - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - # Also disable radio model components specifically - quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} - - if not model_is_already_quantized or calibration_only: - # Only run single sample for preview - input_ids = next(iter(calib_dataloader))[ - "input_features" if model_type == "whisper" else "input_ids" - ][0:1] - - # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: - generated_ids_before_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "before quantization", - allow_fallback=True, - ) - else: - # Standard generation for non-Nemotron VL models - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": - print("Applying nvfp4 quantization (MoE only) for gpt-oss") - - # quantize the model - model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only) - - # For VL models, update full_model to use the quantized language model - if is_nemotron_vl_model: - language_model_lineage = get_language_model_from_vl(full_model) - if language_model_lineage is not None: - print("Updating full_model with quantized language_model...") - language_model_lineage[-2].language_model = model - - if args.verbose: - mtq.print_quant_summary(full_model) - - # Run some samples - torch.cuda.empty_cache() - generated_ids_after_ptq = None - if model_type != "llama4" and not is_nemotron_vl_model: - # Our fake quantizer may not be fully compatible with torch.compile. 
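The {"enable": False} entries used in mono_quantize above work by glob-matching quantizer names, so whole submodules can be excluded from quantization with a few patterns. A small sketch of the mechanism, reusing the patterns from the Nemotron VL and phi4mm branches; NVFP4_DEFAULT_CFG is assumed here only as a stand-in for whatever QUANT_CFG_CHOICES[qformat] resolves to.

# Sketch only, not part of the patch.
import copy
import modelopt.torch.quantization as mtq

quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)  # assumed stand-in base config
for pattern in ("*vision*", "*image*", "*radio*", "*visual*"):
    # quantizers whose module names match the pattern are skipped entirely
    quant_cfg["quant_cfg"][pattern] = {"enable": False}
# language-model-only quantization then proceeds as usual, e.g.
# mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)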
- generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) - elif is_nemotron_vl_model and tokenizer is not None: - generated_ids_after_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "after quantization", - allow_fallback=False, - ) - else: - warnings.warn( - "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." - ) - - def input_decode(input_ids): - if processor is not None and isinstance(processor, MllamaImageProcessor): - return processor.tokenizer.batch_decode(input_ids) - elif processor is not None and isinstance(processor, WhisperProcessor): - return first_text - elif tokenizer is not None: - return tokenizer.batch_decode(input_ids) - else: - raise ValueError("The processor or tokenizer must be set") - - def output_decode(generated_ids, input_shape): - if is_enc_dec(model_type): - if processor is not None and isinstance(processor, WhisperProcessor): - return processor.tokenizer.batch_decode( - generated_ids, skip_special_tokens=True - )[0] - elif tokenizer is not None: - return tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - elif processor is not None and isinstance(processor, MllamaImageProcessor): - return processor.tokenizer.batch_decode(generated_ids[:, input_shape:]) - elif tokenizer is not None: - return tokenizer.batch_decode(generated_ids[:, input_shape:]) - else: - raise ValueError("The processor or tokenizer must be set") - - if generated_ids_after_ptq is not None: - print("--------") - if is_nemotron_vl_model: - # For Nemotron VL models, generated_ids are text strings from model.chat() - print("Nemotron VL model text-only generation results:") - print(f"Text response before quantization: {generated_ids_before_ptq}") - print("--------") - print(f"Text response after quantization: {generated_ids_after_ptq}") - print("--------") - print("Note: Additional VL tests with images were run separately above") - else: - # For regular LLMs, generated_ids are token tensors that need decoding - print(f"example test input: {input_decode(input_ids)}") - print("--------") - print( - f"example outputs before ptq: {output_decode(generated_ids_before_ptq, input_ids.shape[1])}" - ) - print("--------") - print( - f"example outputs after ptq: {output_decode(generated_ids_after_ptq, input_ids.shape[1])}" - ) - else: - warnings.warn("Skipping quantization: model is already quantized.") + language_model_lineage = get_language_model_from_vl(full_model) + if language_model_lineage is not None: + print("Updating full_model with quantized language_model...") + language_model_lineage[-2].language_model = language_model else: - assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" - print(f"qformat: {args.qformat}. No quantization applied, export {device} model") + warnings.warn("Skipping quantization: model is already quantized.") + +def export_quantized( + args: argparse.Namespace, + full_model: torch.nn.Module, + language_model: torch.nn.Module, + model_type: str | None, + tokenizer: PreTrainedTokenizerBase | None, + default_padding_side, +): with torch.inference_mode(): if model_type is None: - print(f"Unknown model type {type(model).__name__}. Continue exporting...") - model_type = f"unknown:{type(model).__name__}" + print(f"Unknown model type {type(language_model).__name__}. 
Continue exporting...") + model_type = f"unknown:{type(language_model).__name__}" export_path = args.export_path @@ -644,12 +494,11 @@ def output_decode(generated_ids, input_shape): print("This is normal for some VLM architectures that don't use AutoProcessor") if model_type == "mllama": - full_model_config = model.config - model = model.language_model + full_model_config = full_model.config # TRT-LLM expects both the vision_config and text_config to be set for export. - setattr(model.config, "vision_config", full_model_config.vision_config) - setattr(model.config, "text_config", full_model_config.text_config) - setattr(model.config, "architectures", full_model_config.architectures) + setattr(full_model.config, "vision_config", full_model_config.vision_config) + setattr(full_model.config, "text_config", full_model_config.text_config) + setattr(full_model.config, "architectures", full_model_config.architectures) start_time = time.time() if ( @@ -662,10 +511,10 @@ def output_decode(generated_ids, input_shape): ) # Move meta tensor back to device before exporting. - remove_hook_from_module(model, recurse=True) + remove_hook_from_module(language_model, recurse=True) export_tensorrt_llm_checkpoint( - model, + language_model, model_type, export_dir=export_path, inference_tensor_parallel=args.inference_tensor_parallel, @@ -701,11 +550,265 @@ def output_decode(generated_ids, input_shape): end_time = time.time() print( - f"Quantized model exported to :{export_path}. Total time used {end_time - start_time}s" + f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" ) -if __name__ == "__main__": +def pre_quantize( + args: argparse.Namespace, + full_model: torch.nn.Module, + model_type: str | None, + tokenizer: PreTrainedTokenizerBase | None, + calib_dataloader: DataLoader, + is_nemotron_vl_model: bool, +): + """ + Processing before the quantization. + + Currently we run one round of generation for a sample prompt, to be compared with + post-quantize generation. + + """ + # Only run single sample for preview + preview_input_ids = next(iter(calib_dataloader))[ + "input_features" if model_type == "whisper" else "input_ids" + ][0:1] + + # Generate preview before quantization + if is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + preview_input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, + ) + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": + print("Applying nvfp4 quantization (MoE only) for gpt-oss") + + return preview_input_ids, generated_ids_before_ptq + + +def post_quantize( + args: argparse.Namespace, + full_model: torch.nn.Module, + model_type: str | None, + tokenizer: PreTrainedTokenizerBase | None, + processor: BaseImageProcessor | ProcessorMixin | None, + preview_input_ids, + generated_ids_before_ptq, + is_nemotron_vl_model, + first_text_speech_dataset, +): + """ + Processing after the quantization. + + Currently we run one round of generation using the quantized model for a sample prompt, + and compare it with pre-quantize generation. 
+ + """ + + if args.verbose: + mtq.print_quant_summary(full_model) + + # Run some samples + torch.cuda.empty_cache() + generated_ids_after_ptq = None + if model_type != "llama4" and not is_nemotron_vl_model: + # Our fake quantizer may not be fully compatible with torch.compile. + generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) + elif is_nemotron_vl_model and tokenizer is not None: + generated_ids_after_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + preview_input_ids, + args.pyt_ckpt_path, + "after quantization", + allow_fallback=False, + ) + else: + warnings.warn( + "Llama4 Maverick generation after quantization has a bug. Skipping generation sample." + ) + + def input_decode(input_ids): + if processor is not None and isinstance(processor, MllamaImageProcessor): + return processor.tokenizer.batch_decode(input_ids) + elif processor is not None and isinstance(processor, WhisperProcessor): + return first_text_speech_dataset + elif tokenizer is not None: + return tokenizer.batch_decode(input_ids) + else: + raise ValueError("The processor or tokenizer must be set") + + def output_decode(generated_ids, input_shape): + if is_enc_dec(model_type): + if processor is not None and isinstance(processor, WhisperProcessor): + return processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + elif tokenizer is not None: + return tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + elif processor is not None and isinstance(processor, MllamaImageProcessor): + return processor.tokenizer.batch_decode(generated_ids[:, input_shape:]) + elif tokenizer is not None: + return tokenizer.batch_decode(generated_ids[:, input_shape:]) + else: + raise ValueError("The processor or tokenizer must be set") + + if generated_ids_after_ptq is not None: + print("--------") + if is_nemotron_vl_model: + # For Nemotron VL models, generated_ids are text strings from model.chat() + print("Nemotron VL model text-only generation results:") + print(f"Text response before quantization: {generated_ids_before_ptq}") + print("--------") + print(f"Text response after quantization: {generated_ids_after_ptq}") + print("--------") + print("Note: Additional VL tests with images were run separately above") + else: + # For regular LLMs, generated_ids are token tensors that need decoding + print(f"example test input: {input_decode(preview_input_ids)}") + print("--------") + print( + f"example outputs before ptq: {output_decode(generated_ids_before_ptq, preview_input_ids.shape[1])}" + ) + print("--------") + print( + f"example outputs after ptq: {output_decode(generated_ids_after_ptq, preview_input_ids.shape[1])}" + ) + + +def quantize_main( + args: argparse.Namespace, + full_model: torch.nn.Module, + language_model: torch.nn.Module, + model_type: str | None, + calibration_only: bool, + processor: BaseImageProcessor | ProcessorMixin | None, + tokenizer: PreTrainedTokenizerBase | None, + default_padding_side, + device: torch.device, +): + if args.batch_size == 0: + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. 
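pre_quantize() and post_quantize() above reduce to generating from the same single prompt on either side of quantization and printing both decodings. A compact sketch of that round trip for a plain text LLM (Whisper and Mllama use different batch keys and decoders), assuming full_model, tokenizer, and calib_dataloader are the objects produced earlier in this script:

# Sketch only, not part of the patch.
preview_input_ids = next(iter(calib_dataloader))["input_ids"][0:1]  # single-sample preview
ids_before = full_model.generate(preview_input_ids, max_new_tokens=100)
# ... mono_quantize() or auto_quantize() runs in between ...
ids_after = full_model.generate(preview_input_ids, max_new_tokens=100)
prompt_len = preview_input_ids.shape[1]
print("before ptq:", tokenizer.batch_decode(ids_before[:, prompt_len:]))
print("after ptq:", tokenizer.batch_decode(ids_after[:, prompt_len:]))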
+ sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 + # Whisper model expects mel-spectrogram input features of length 3000 + # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) + # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float + # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() + if model_type == "whisper": + max_sample_length = 3000 + num_mel_bins = language_model.config.num_mel_bins + sample_input_single_batch = ( + torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( + language_model.device + ) + * 100 + ) + else: + sample_input_single_batch = None + + run_auto_quant = args.auto_quantize_bits is not None + + args.batch_size = get_max_batch_size( + language_model, + max_sample_length=args.calib_seq, + sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, + sample_input_single_batch=sample_input_single_batch, + enable_grad=run_auto_quant, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + + calib_dataloader, first_text_speech_dataset = make_calib_dataloader( + args, language_model, processor, tokenizer, device, model_type + ) + + # Detect if this is a Nemotron VL model using architecture-based detection + is_nemotron_vl_model = is_nemotron_vl(full_model) + + preview_input_ids, generated_ids_before_ptq = pre_quantize( + args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model + ) + + if args.auto_quantize_bits: + assert len(args.qformat.split(",")) > 1, ( + "Auto quantization needs multiple quantization format." + ) + + auto_quantize( + args, + language_model, + calib_dataloader, + ) + + else: + # mono quantization + assert len(args.qformat.split(",")) == 1, ( + "Plain quantization supports only one quantization format." + ) + + assert ( + args.qformat + in [ + "int8_wo", + "int4_awq", + "fp8", + "nvfp4", + "nvfp4_awq", + "w4a8_awq", + "fp8_pb_wo", + "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + ] + or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES + ), f"Plain quantization format {args.qformat} not supported for HF export path" + + quant_cfg = build_quant_cfg( + args.qformat, + args.kv_cache_qformat, + args.awq_block_size, + model_type, + QUANT_CFG_CHOICES, + KV_QUANT_CFG_CHOICES, + ) + + if args.qformat in QUANT_CFG_CHOICES: + mono_quantize( + args, + quant_cfg, + full_model, + language_model, + model_type, + calibration_only, + calib_dataloader, + is_nemotron_vl_model, + ) + else: + assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" + print(f"qformat: {args.qformat}. 
No quantization applied, export {device} model") + + post_quantize( + args, + full_model, + model_type, + tokenizer, + processor, + preview_input_ids, + generated_ids_before_ptq, + is_nemotron_vl_model, + first_text_speech_dataset, + ) + export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side) + + +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--pyt_ckpt_path", @@ -866,7 +969,53 @@ def output_decode(generated_ids, input_shape): ), ) - args = parser.parse_args() + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. + torch.compiler.set_stance("force_eager") + + ( + full_model, + language_model, + model_type, + calibration_only, + processor, + tokenizer, + default_padding_side, + device, + ) = load_model(args) + + if args.sparsity_fmt != "dense": + # Sparse + sparsity_main(args, full_model, tokenizer, device) + else: + # Quantize + quantize_main( + args, + full_model, + language_model, + model_type, + calibration_only, + processor, + tokenizer, + default_padding_side, + device, + ) + + +if __name__ == "__main__": + args = parse_args() if args.export_fmt != "hf": warnings.warn("Deprecated. --export_fmt forced to hf.") diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index dc3c5e4b0..2ae7dde4a 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -332,7 +332,6 @@ def main(args): args.qformat, args.kv_cache_qformat, args.awq_block_size, - None, model_type, QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES, diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index d4cf5049d..7908ec514 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -173,7 +173,7 @@ def get_dataset_dataloader( batch_size: int = 1, num_samples: int | list[int] = 512, max_sample_length: int = 512, - device: str | None = None, + device: torch.device | None = None, include_labels: bool = False, ) -> DataLoader: """Get a dataloader with the dataset name and toknizer of the target model. 
@@ -264,7 +264,7 @@ def get_max_batch_size( model: torch.nn.Module, max_sample_length: int = 512, sample_memory_usage_ratio: float = 1.0, - sample_input_single_batch: torch.Tensor = None, + sample_input_single_batch: torch.Tensor | None = None, enable_grad: bool = False, ): """Get the maximum batch size that can be used for the model.""" diff --git a/modelopt/torch/utils/image_processor.py b/modelopt/torch/utils/image_processor.py index 87960d54d..6374642e3 100644 --- a/modelopt/torch/utils/image_processor.py +++ b/modelopt/torch/utils/image_processor.py @@ -22,7 +22,7 @@ class BaseImageProcessor: """Base class for image processors.""" - def __init__(self, tokenizer, device="auto"): + def __init__(self, tokenizer, device="cuda"): """Constructor.""" self.tokenizer = tokenizer self.device = device diff --git a/modelopt/torch/utils/speech_dataset_utils.py b/modelopt/torch/utils/speech_dataset_utils.py index 0d414f7ec..a71d73773 100644 --- a/modelopt/torch/utils/speech_dataset_utils.py +++ b/modelopt/torch/utils/speech_dataset_utils.py @@ -79,12 +79,12 @@ def get_supported_speech_datasets() -> list[str]: def get_speech_dataset_dataloader( dataset_name: str = "peoples_speech", - processor: WhisperProcessor = None, + processor: WhisperProcessor | None = None, batch_size: int = 1, num_samples: int = 512, - device: str | None = None, + device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> DataLoader: +) -> tuple[DataLoader, str]: """Get a dataloader with the dataset name and processor of the target model. Args:
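With get_speech_dataset_dataloader now returning a tuple, call sites unpack both the dataloader and the accompanying reference text, as make_calib_dataloader does in hf_ptq.py above. A minimal sketch, with placeholders where the example passes runtime objects:

# Sketch only, not part of the patch.
import torch
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader

calib_dataloader, first_text = get_speech_dataset_dataloader(
    dataset_name="peoples_speech",
    processor=whisper_processor,   # placeholder: a transformers WhisperProcessor instance
    batch_size=1,
    num_samples=512,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,          # the example passes language_model.dtype here
)
# first_text is the decoded reference text that post_quantize()'s input_decode()
# prints for Whisper models.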