diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 065913464..49b83bc57 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -13,6 +13,7 @@ Model Optimizer Changelog (Linux)
 - Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.
 - Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` if no dataset is specified.
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
+- Support ``DeepSeek V3.2`` model quantization. See ``examples/deepseek`` for more details.

 **Documentation**
diff --git a/examples/deepseek/.gitignore b/examples/deepseek/.gitignore
index 08671abd4..9ff38b490 100644
--- a/examples/deepseek/.gitignore
+++ b/examples/deepseek/.gitignore
@@ -1 +1,2 @@
 DeepSeek-V3/
+DeepSeek-V3.2-Exp/
diff --git a/examples/deepseek/README.md b/examples/deepseek/README.md
index 34097f227..9724296be 100644
--- a/examples/deepseek/README.md
+++ b/examples/deepseek/README.md
@@ -6,7 +6,7 @@ This example will demonstrate the steps to quantize DeepSeek R1 model to FP4 and

 Due to the model size, currently it requires 8xH200 or 16xH100 to quantize the FP8 model, we will use 8xH200 as example.

-### Convert the HF checkpoint for deepseek FP8 inference
+## Convert the HF checkpoint for DeepSeek FP8 inference

 ```bash
 # set up variables to run the example
@@ -14,26 +14,56 @@ export HF_FP8_CKPT={path_to_downloaded_hf_checkpoint}
 export DS_CKPT={path_to_save_converted_checkpoint}
 export FP4_QUANT_PATH={path_to_save_quantization_results}
 export HF_FP4_PATH={path_to_save_the_final_FP4_checkpoint}
+```
+
+### DeepSeek V3, R1, V3.1

-# download the FP8 checkpoint from Hugginface
+```bash
+# download the FP8 checkpoint from Hugging Face. This example uses DeepSeek-R1
 huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir $HF_FP8_CKPT

 # clone DeepSeek-V3 (base model of R1) Github repository for FP8 inference,
 git clone https://github.com/deepseek-ai/DeepSeek-V3.git && cd DeepSeek-V3 && git checkout 1398800
+```
+
+### DeepSeek V3.2
+```bash
+# download the FP8 checkpoint from Hugging Face
+huggingface-cli download deepseek-ai/DeepSeek-V3.2-Exp --local-dir $HF_FP8_CKPT
+
+# clone the DeepSeek-V3.2-Exp GitHub repository for FP8 inference
+git clone https://github.com/deepseek-ai/DeepSeek-V3.2-Exp.git && cd DeepSeek-V3.2-Exp && git checkout 3b99a53
+
+# install requirements
+pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
+pip install -r inference/requirements.txt
+```
+
+### Convert the Checkpoint
+
+```bash
 # convert the HF checkpoint to a specific format for Deepseek
 python inference/convert.py --hf-ckpt-path $HF_FP8_CKPT --save-path $DS_CKPT --n-experts 256 --model-parallel 8
 ```

-### Post-training quantization
+## Post-training quantization
+
+### Run the calibration scripts

-#### Run the calibration scripts
+DeepSeek V3, R1, V3.1

 ```bash
 torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3/inference/configs/config_671B.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH
 ```

-#### Quantize the FP8 hf checkpoint to FP4
+DeepSeek V3.2
+
+```bash
+torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH
+```
+
+### Quantize the FP8 HF checkpoint to FP4

 We provide a one-step-script which will:
diff --git a/examples/deepseek/ds_kernel.py b/examples/deepseek/ds_kernel.py
new file mode 100644
index 000000000..8ac5a1fa9
--- /dev/null
+++ b/examples/deepseek/ds_kernel.py
@@ -0,0 +1,95 @@
+# MIT License
+
+# Copyright (c) 2023 DeepSeek
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import triton
+import triton.language as tl
+
+"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""
+
+
+@triton.jit
+def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Dequantizes weights using the provided scaling factors and stores the result.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the quantized weights.
+        s_ptr (tl.pointer): Pointer to the scaling factors.
+        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+    Returns:
+        None
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n = tl.cdiv(N, BLOCK_SIZE)
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    s = tl.load(s_ptr + pid_m * n + pid_n)
+    y = x * s
+    tl.store(y_ptr + offs, y, mask=mask)
+
+
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given weight tensor using the provided scale tensor.
+
+    Args:
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
+        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+    Returns:
+        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+    Raises:
+        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+    """
+    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
+    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
+    M, N = x.size()
+    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
+    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    return y
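The vendored `weight_dequant` above expects an FP8 weight of shape `(M, N)` and one scaling factor per `128x128` block, so the example scripts can import it directly instead of reaching into a cloned DeepSeek repository. A minimal usage sketch, assuming a CUDA device, a Triton install, and a PyTorch build with `torch.float8_e4m3fn`; the shapes and random tensors are illustrative only and not taken from this PR:

```python
import torch

from ds_kernel import weight_dequant

# Illustrative stand-ins for one FP8 block-quantized weight shard and its block scales.
M, N, block = 512, 1024, 128
weight_fp8 = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)   # quantized weight, shape (M, N)
block_scales = torch.rand(M // block, N // block, device="cuda") + 0.5  # one scale per 128x128 block

weight_fp32 = weight_dequant(weight_fp8, block_scales, block_size=block)
print(weight_fp32.shape, weight_fp32.dtype)  # (512, 1024) in the default dtype
```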
diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py
index 064071210..b091ddc0e 100644
--- a/examples/deepseek/ptq.py
+++ b/examples/deepseek/ptq.py
@@ -64,9 +64,21 @@
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 from modelopt.torch.utils.distributed import ParallelState

-sys.path.append(str(Path(__file__).resolve().parent / "DeepSeek-V3/inference"))
-import model as deekseep_model
-from kernel import act_quant, fp8_gemm, weight_dequant
+DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference"
+DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference"
+
+if DS_V3_2_PATH.exists():
+    sys.path.append(str(DS_V3_2_PATH))
+elif DS_V3_PATH.exists():
+    sys.path.append(str(DS_V3_PATH))
+else:
+    raise ValueError(
+        f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}"
+    )
+
+import model as deekseep_model  # noqa: E402
+from ds_kernel import weight_dequant  # noqa: E402
+from kernel import act_quant, fp8_gemm  # noqa: E402


 def monkey_patch_deepseek_model():
@@ -186,6 +198,26 @@ def _setup(self):
             self.kv_bmm_quantizer = TensorQuantizer()
             self.pe_bmm_quantizer = TensorQuantizer()

+    class CalibMoe(deekseep_model.MoE):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self._setup()
+
+        def _setup(self):
+            self._original_topk = self.gate.topk
+            self._original_topk_groups = self.gate.topk_groups
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            # Forward all tokens to all experts for calibration
+            self.gate.topk = self.n_routed_experts
+            self.gate.topk_groups = self.gate.n_groups
+            super().forward(x)
+            # Restore the original topk and topk_groups
+            self.gate.topk = self._original_topk
+            self.gate.topk_groups = self._original_topk_groups
+
+            return super().forward(x)
+
     mtq.register(
         original_cls=deekseep_model.RowParallelLinear,
         quantized_cls=QuantRowParallelLinear,
@@ -196,6 +228,7 @@ def _setup(self):
     )
     mtq.register(original_cls=deekseep_model.Linear, quantized_cls=QuantLinear)
     mtq.register(original_cls=deekseep_model.MLA, quantized_cls=QuantMLA)
+    mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe)


 def load_deepseek_model(model_config: str, model_path: str, batch_size: int):
@@ -243,10 +276,10 @@ def ptq(
     ## create dataset
     device = next(model.parameters()).device
     calib_dataset = get_dataset_dataloader(
-        dataset_name="cnn_dailymail",
+        dataset_name=["cnn_dailymail", "nemotron-post-training-dataset-v2"],
         tokenizer=tokenizer,
         batch_size=batch_size,
-        num_samples=calib_size,
+        num_samples=[calib_size, calib_size],
         device=device,
     )

@@ -307,6 +340,13 @@ def state_dict_filter(state_dict):
             os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"),
         )

+    # if rank == 0:
+    #     with open("expert_activation_counts.txt", "w") as f:
+    #         for name, module in model.named_modules():
+    #             if isinstance(module, deekseep_model.MoE):
+    #                 counts = module.activated_expert_counts()
+    #                 f.writelines(f"{name}: {count}\n" for count in counts)
+
     quant_config = get_quant_config(model.named_modules())

     if enable_fp8_kvcache:
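The `CalibMoe` wrapper registered above temporarily widens the gate so that every token is routed to every expert during calibration, letting each expert's quantizers collect amax statistics, and then restores the original `topk`/`topk_groups` before producing the output that is actually returned. A simplified sketch of how the registered class gets exercised during ModelOpt calibration; this is not the full ptq.py flow (which also handles tensor-parallel loading and the DeepSeek-specific forward signature), and `model`, `tokenizer`, `batch_size`, `calib_size`, and the `batch["input_ids"]` access are assumptions for illustration:

```python
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader

# Assumes `model`, `tokenizer`, `batch_size`, and `calib_size` already exist
# (ptq.py builds them from its CLI arguments and load_deepseek_model).
calib_loader = get_dataset_dataloader(
    dataset_name=["cnn_dailymail", "nemotron-post-training-dataset-v2"],
    tokenizer=tokenizer,
    batch_size=batch_size,
    num_samples=[calib_size, calib_size],
    device=next(model.parameters()).device,
)


def forward_loop(model):
    # Each calibration forward pass hits CalibMoe.forward, which routes tokens
    # to all experts for amax collection before restoring the original routing.
    for batch in calib_loader:
        model(batch["input_ids"])


# Because MoE is registered with quantized_cls=CalibMoe, mtq.quantize swaps it in
# before running the calibration loop.
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
```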
diff --git a/examples/deepseek/quantize_fp8_to_nvfp4.sh b/examples/deepseek/quantize_fp8_to_nvfp4.sh
index 8dd8f4fcd..ae24e2bfd 100755
--- a/examples/deepseek/quantize_fp8_to_nvfp4.sh
+++ b/examples/deepseek/quantize_fp8_to_nvfp4.sh
@@ -78,7 +78,9 @@ fi

 # Copy miscellaneous files to the quantized checkpoint
 mkdir -p $FP4_PATH
-cp $FP8_HF_PATH/*.json $FP8_HF_PATH/*.py $FP4_PATH/
+cp $FP8_HF_PATH/*.json $FP4_PATH/
+cp $FP8_HF_PATH/*.py $FP4_PATH/ || true
+cp -r $FP8_HF_PATH/assets $FP4_PATH/ || true

 # Run the quantization command
 echo "Running quantization..."
diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py
index 542d09abd..d94f48fce 100644
--- a/examples/deepseek/quantize_to_nvfp4.py
+++ b/examples/deepseek/quantize_to_nvfp4.py
@@ -41,19 +41,15 @@
 import glob
 import json
 import os
-import sys
-from pathlib import Path
 from typing import Any

 import torch
+from ds_kernel import weight_dequant
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm

 from modelopt.torch.quantization.qtensor import NVFP4QTensor

-sys.path.append(str(Path(__file__).resolve().parent / "DeepSeek-V3/inference"))
-from kernel import weight_dequant
-

 def _remap_key(key_dict: dict[str, Any]):
     # renaming the module to match HF modeling
@@ -155,7 +151,7 @@ def convert_fp8_ckpt_to_nvfp4(
     per_layer_quant_config,
 ):
     def amax_to_nvfp4_scaling_factor_2(amax):
-        return amax.float() / 6.0 / 448.0
+        return amax.float() / (6.0 * 448.0)

     def amax_to_fp8_scaling_factor(amax):
         return amax.float() / 448.0
diff --git a/pyproject.toml b/pyproject.toml
index 8ae14292d..77bb68428 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,7 @@ extend-ignore = [
 "*/_[a-zA-Z]*" = ["D"] # Private packages (_abc/*.py) or modules (_xyz.py)
 "*.ipynb" = ["D", "E501"] # Ignore missing docstrings or line length for Jupyter notebooks
 "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style
-
+"examples/deepseek/ds_kernel.py" = ["N803", "N806", "E731"] # triton style

 [tool.ruff.lint.pycodestyle]
 max-line-length = 120 # Line length limit for comments and docstrings
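One small note on the `amax_to_nvfp4_scaling_factor_2` change in `quantize_to_nvfp4.py`: `amax.float() / 6.0 / 448.0` and `amax.float() / (6.0 * 448.0)` are mathematically identical (floating-point rounding aside); the rewrite only makes the intent explicit. The second-level NVFP4 scaling factor divides the calibrated amax by the product of the FP4 (E2M1) maximum 6.0 and the FP8 (E4M3) maximum 448.0, while the FP8 scaling factor in the same file divides by 448.0 alone. A small self-contained check, with an arbitrary amax value chosen purely for illustration:

```python
import torch

amax = torch.tensor(512.0)  # arbitrary calibrated absolute maximum

# Second-level NVFP4 scaling factor: amax / (fp4_max * fp8_max)
scale_2_before = amax.float() / 6.0 / 448.0    # chained division (old form)
scale_2_after = amax.float() / (6.0 * 448.0)   # explicit product (new form)
assert torch.allclose(scale_2_before, scale_2_after)

# FP8 scaling factor defined alongside it: amax / fp8_max
scale_fp8 = amax.float() / 448.0
print(scale_2_after.item(), scale_fp8.item())
```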