-
Notifications
You must be signed in to change notification settings - Fork 183
[OMNIML-2857] Support the DeepSeek V3.2 model #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
19722c9
4e4bf16
90865c3
9b64663
ea9190e
2223d4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| DeepSeek-V3/ | ||
| DeepSeek-V3.2-Exp/ |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,95 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # MIT License | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2023 DeepSeek | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # of this software and associated documentation files (the "Software"), to deal | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # in the Software without restriction, including without limitation the rights | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # copies of the Software, and to permit persons to whom the Software is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # furnished to do so, subject to the following conditions: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # The above copyright notice and this permission notice shall be included in all | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # copies or substantial portions of the Software. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SOFTWARE. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
meenchen marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import triton | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import triton.language as tl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @triton.jit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
meenchen marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Dequantizes weights using the provided scaling factors and stores the result. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x_ptr (tl.pointer): Pointer to the quantized weights. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_ptr (tl.pointer): Pointer to the scaling factors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M (int): Number of rows in the weight matrix. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N (int): Number of columns in the weight matrix. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BLOCK_SIZE (tl.constexpr): Size of the block for tiling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pid_m = tl.program_id(axis=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pid_n = tl.program_id(axis=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n = tl.cdiv(N, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offs = offs_m[:, None] * N + offs_n[None, :] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s = tl.load(s_ptr + pid_m * n + pid_n) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y = x * s | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tl.store(y_ptr + offs, y, mask=mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Dequantizes the given weight tensor using the provided scale tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x (torch.Tensor): The quantized weight tensor of shape (M, N). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size (int, optional): The block size to use for dequantization. Defaults to 128. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.Tensor: The dequantized weight tensor of the same shape as `x`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M, N = x.size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y = torch.empty_like(x, dtype=torch.get_default_dtype()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return y | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+79
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Block-scale tensor shape can go out-of-bounds The docstring and lack of shape checks let callers size Please validate the shape up front (using ceil-div) and update the docstring accordingly, e.g.: @@
- s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
+ s (torch.Tensor): The scale tensor of shape (ceil_div(M, block_size), ceil_div(N, block_size)).
@@
- M, N = x.size()
+ M, N = x.size()
+ m_blocks = (M + block_size - 1) // block_size
+ n_blocks = (N + block_size - 1) // block_size
+ assert s.size() == (m_blocks, n_blocks), \
+ f"Expected s.shape == ({m_blocks}, {n_blocks}), got {tuple(s.size())}"This keeps the kernel within bounds and matches its launch configuration. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,9 +64,21 @@ | |
| from modelopt.torch.utils.dataset_utils import get_dataset_dataloader | ||
| from modelopt.torch.utils.distributed import ParallelState | ||
|
|
||
| sys.path.append(str(Path(__file__).resolve().parent / "DeepSeek-V3/inference")) | ||
| import model as deekseep_model | ||
| from kernel import act_quant, fp8_gemm, weight_dequant | ||
| DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference" | ||
| DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference" | ||
|
|
||
| if DS_V3_2_PATH.exists(): | ||
| sys.path.append(str(DS_V3_2_PATH)) | ||
| elif DS_V3_PATH.exists(): | ||
| sys.path.append(str(DS_V3_PATH)) | ||
| else: | ||
| raise ValueError( | ||
| f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}" | ||
| ) | ||
|
|
||
| import model as deekseep_model # noqa: E402 | ||
| from ds_kernel import weight_dequant # noqa: E402 | ||
| from kernel import act_quant, fp8_gemm # noqa: E402 | ||
|
|
||
|
|
||
| def monkey_patch_deepseek_model(): | ||
|
|
@@ -186,6 +198,26 @@ def _setup(self): | |
| self.kv_bmm_quantizer = TensorQuantizer() | ||
| self.pe_bmm_quantizer = TensorQuantizer() | ||
|
|
||
| class CalibMoe(deekseep_model.MoE): | ||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| self._setup() | ||
|
|
||
| def _setup(self): | ||
| self._original_topk = self.gate.topk | ||
| self._original_topk_groups = self.gate.topk_groups | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| # Forward all tokens to all experts for calibration | ||
| self.gate.topk = self.n_routed_experts | ||
| self.gate.topk_groups = self.gate.n_groups | ||
| super().forward(x) | ||
| # Restore the original topk and topk_groups | ||
| self.gate.topk = self._original_topk | ||
| self.gate.topk_groups = self._original_topk_groups | ||
|
|
||
| return super().forward(x) | ||
|
|
||
|
Comment on lines
+201
to
+220
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the double forward pass in CalibMoe. The
This double invocation means every forward pass during calibration runs inference twice through the MoE layer, which is extremely expensive for large models. Expected behavior: The forward pass should only route to all experts during calibration phase (when quantizers collect statistics), then switch to normal routing once calibration is complete. The current implementation runs both routing strategies on every call. Consider this fix: class CalibMoe(deekseep_model.MoE):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._setup()
def _setup(self):
- self._original_topk = self.gate.topk
- self._original_topk_groups = self.gate.topk_groups
+ # Route to all experts during calibration
+ self.gate.topk = self.n_routed_experts
+ self.gate.topk_groups = self.gate.n_groups
def forward(self, x: torch.Tensor) -> torch.Tensor:
- # Forward all tokens to all experts for calibration
- self.gate.topk = self.n_routed_experts
- self.gate.topk_groups = self.gate.n_groups
- super().forward(x)
- # Restore the original topk and topk_groups
- self.gate.topk = self._original_topk
- self.gate.topk_groups = self._original_topk_groups
-
return super().forward(x)This sets the all-expert routing once during setup and keeps it throughout calibration, eliminating the double forward pass. |
||
| mtq.register( | ||
| original_cls=deekseep_model.RowParallelLinear, | ||
| quantized_cls=QuantRowParallelLinear, | ||
|
|
@@ -196,6 +228,7 @@ def _setup(self): | |
| ) | ||
| mtq.register(original_cls=deekseep_model.Linear, quantized_cls=QuantLinear) | ||
| mtq.register(original_cls=deekseep_model.MLA, quantized_cls=QuantMLA) | ||
| mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe) | ||
|
|
||
|
|
||
| def load_deepseek_model(model_config: str, model_path: str, batch_size: int): | ||
|
|
@@ -243,10 +276,10 @@ def ptq( | |
| ## create dataset | ||
| device = next(model.parameters()).device | ||
| calib_dataset = get_dataset_dataloader( | ||
| dataset_name="cnn_dailymail", | ||
| dataset_name=["cnn_dailymail", "nemotron-post-training-dataset-v2"], | ||
| tokenizer=tokenizer, | ||
| batch_size=batch_size, | ||
| num_samples=calib_size, | ||
| num_samples=[calib_size, calib_size], | ||
| device=device, | ||
| ) | ||
|
|
||
|
|
@@ -307,6 +340,13 @@ def state_dict_filter(state_dict): | |
| os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"), | ||
| ) | ||
|
|
||
| # if rank == 0: | ||
| # with open("expert_activation_counts.txt", "w") as f: | ||
| # for name, module in model.named_modules(): | ||
| # if isinstance(module, deekseep_model.MoE): | ||
| # counts = module.activated_expert_counts() | ||
| # f.writelines(f"{name}: {count}\n" for count in counts) | ||
|
|
||
| quant_config = get_quant_config(model.named_modules()) | ||
|
|
||
| if enable_fp8_kvcache: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a
cd ..before running the calibration command.The V3.2 setup block leaves us inside
DeepSeek-V3.2-Exp. If we run the calibration command as written from there, the path resolves toDeepSeek-V3.2-Exp/DeepSeek-V3.2-Exp/...and fails. Please add a step (e.g.,cd ..) after installing requirements so readers return to the project root before launching calibration.🤖 Prompt for AI Agents