Commit 9f2a3ae

[None][fix] Restrict tinygemm use to certain SMs (#8182)

Signed-off-by: Dongfeng Yu <[email protected]>
Signed-off-by: dongfengy <[email protected]>

1 parent ed8e00a · commit 9f2a3ae

File tree: 2 files changed, +7 −1 lines changed

cpp/tensorrt_llm/thop/tinygemm2.cpp (3 additions, 0 deletions)

```diff
@@ -30,6 +30,9 @@ namespace torch_ext
 {
 torch::Tensor tinygemm2_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias)
 {
+    auto const smVersion = tensorrt_llm::common::getSMVersion();
+    TORCH_CHECK(
+        smVersion == 90 || smVersion == 100 || smVersion == 103, "tinygemm2 only supports SM90, SM100, and SM103.");
     TORCH_CHECK(input.dim() == 2, "input must be 2D");
     TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
     TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
```
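The hunk above makes the C++ operator fail fast when called on a GPU architecture that has no tinygemm2 kernel, instead of misbehaving later. A minimal Python analogue of that guard, assuming illustrative names (`SUPPORTED_SMS`, `check_tinygemm2_supported` are not TensorRT-LLM API):

```python
# Hedged sketch of the fail-fast guard added by this commit.
# SM versions with a tinygemm2 kernel, per the diff: SM90, SM100, SM103.
SUPPORTED_SMS = (90, 100, 103)

def check_tinygemm2_supported(sm_version: int) -> None:
    """Raise immediately if this GPU architecture has no tinygemm2 kernel."""
    if sm_version not in SUPPORTED_SMS:
        raise RuntimeError(
            f"tinygemm2 only supports SM90, SM100, and SM103; got SM{sm_version}.")
```

Failing at the call boundary with an explicit architecture message is cheaper to debug than an illegal-instruction error deep inside a kernel launch.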

tensorrt_llm/_torch/models/modeling_gpt_oss.py (4 additions, 1 deletion)

```diff
@@ -7,6 +7,7 @@
 from tqdm import tqdm
 from transformers import GptOssConfig

+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType

 from ..attention_backend import AttentionMetadata
@@ -225,7 +226,9 @@ def _create_ideal_expert_load_balanced_logits(
             dtype=pretrained_config.torch_dtype)

     def compute_gate_output(self, x: torch.Tensor) -> torch.Tensor:
-        if x.shape[0] <= MIN_LATENCY_TINYGEMM_NUM_TOKENS:
+        if get_sm_version() in [
+                90, 100, 103
+        ] and x.shape[0] <= MIN_LATENCY_TINYGEMM_NUM_TOKENS:
             weight = self.gate.weight
             bias = self.gate.bias
             g = torch.ops.trtllm.tinygemm2(x, weight, bias)
```
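Unlike the C++ side, which hard-errors, the Python caller gates dispatch so unsupported SMs silently take the generic GEMM path and never reach the tinygemm2 op. That two-condition dispatch can be sketched in isolation; the names below (`pick_gate_kernel`, the threshold value, `"dense_gemm"`) are illustrative stand-ins, not TensorRT-LLM API, and `sm_version` is passed in rather than queried via `tensorrt_llm._utils.get_sm_version()` so the sketch runs without a GPU:

```python
# Hedged sketch of the caller-side dispatch in compute_gate_output.
TINYGEMM_SMS = (90, 100, 103)          # SMs restricted by this commit
MIN_LATENCY_TINYGEMM_NUM_TOKENS = 128  # illustrative threshold, not the real value

def pick_gate_kernel(sm_version: int, num_tokens: int) -> str:
    """Return which kernel the gate projection would use for this batch."""
    if sm_version in TINYGEMM_SMS and num_tokens <= MIN_LATENCY_TINYGEMM_NUM_TOKENS:
        return "tinygemm2"   # latency-optimized path, only on supported SMs
    return "dense_gemm"      # generic fallback for other SMs or larger batches
```

Gating at the call site means the `TORCH_CHECK` in tinygemm2.cpp acts only as a backstop for direct users of `torch.ops.trtllm.tinygemm2`.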
