Skip to content

Commit a21697e

Browse files
authored
[None][fix] fix config loading for DeepSeek-V3.2 in trtllm-bench (#8729)
Signed-off-by: Fanrong Li <[email protected]>
1 parent e2c5a38 commit a21697e

File tree

5 files changed

+76
-86
lines changed

5 files changed

+76
-86
lines changed

tensorrt_llm/_torch/model_config.py

Lines changed: 24 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,8 @@
1212
from transformers.utils import HF_MODULES_CACHE
1313

1414
from tensorrt_llm import logger
15-
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
15+
from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
16+
load_pretrained_config)
1617
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
1718
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
1819
from tensorrt_llm.functional import AllReduceStrategy
@@ -25,18 +26,6 @@
2526
TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
2627

2728

28-
class LazyConfigDict(dict):
29-
30-
def __getitem__(self, key):
31-
import tensorrt_llm._torch.configs as configs
32-
return getattr(configs, super().__getitem__(key))
33-
34-
35-
_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
36-
deepseek_v32="DeepseekV3Config",
37-
) # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class
38-
39-
4029
@dataclass
4130
class MoeLoadBalancerConfig:
4231
num_slots: Optional[int] = None
@@ -432,51 +421,31 @@ def from_pretrained(cls,
432421
# When handling the case where model_format is TLLM_ENGINE
433422
# send cyclic requests to the NONE URL.
434423
if checkpoint_dir is not None:
435-
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
424+
pretrained_config = load_pretrained_config(
436425
checkpoint_dir,
426+
trust_remote_code=trust_remote_code,
437427
**kwargs,
438428
)
439-
model_type = config_dict.get("model_type")
440-
if model_type in _CONFIG_REGISTRY:
441-
config_class = _CONFIG_REGISTRY[model_type]
442-
pretrained_config = config_class.from_pretrained(
443-
checkpoint_dir,
444-
**kwargs,
445-
)
446-
if model_type == "deepseek_v32":
447-
sparse_attention_config = kwargs.get(
448-
'sparse_attention_config')
449-
kwargs[
450-
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
451-
index_n_heads=(
452-
sparse_attention_config.index_n_heads
453-
if sparse_attention_config
454-
and sparse_attention_config.index_n_heads
455-
is not None else
456-
pretrained_config.index_n_heads),
457-
index_head_dim=(
458-
sparse_attention_config.index_head_dim
459-
if sparse_attention_config
460-
and sparse_attention_config.index_head_dim
461-
is not None else
462-
pretrained_config.index_head_dim),
463-
index_topk=(sparse_attention_config.index_topk
464-
if sparse_attention_config and
465-
sparse_attention_config.index_topk
466-
is not None else
467-
pretrained_config.index_topk),
468-
indexer_max_chunk_size=(
469-
sparse_attention_config.
470-
indexer_max_chunk_size
471-
if sparse_attention_config
472-
and sparse_attention_config.
473-
indexer_max_chunk_size is not None else
474-
None))
475-
else:
476-
pretrained_config = transformers.AutoConfig.from_pretrained(
477-
checkpoint_dir,
478-
trust_remote_code=trust_remote_code,
479-
)
429+
if pretrained_config.architectures[
430+
0] == "DeepseekV32ForCausalLM":
431+
sparse_attention_config = kwargs.get(
432+
'sparse_attention_config')
433+
if sparse_attention_config:
434+
index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
435+
index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
436+
index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
437+
indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
438+
else:
439+
index_n_heads = pretrained_config.index_n_heads
440+
index_head_dim = pretrained_config.index_head_dim
441+
index_topk = pretrained_config.index_topk
442+
indexer_max_chunk_size = None
443+
kwargs[
444+
'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
445+
index_n_heads=index_n_heads,
446+
index_head_dim=index_head_dim,
447+
index_topk=index_topk,
448+
indexer_max_chunk_size=indexer_max_chunk_size)
480449
else:
481450
raise ValueError(
482451
"checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."

tensorrt_llm/_torch/pyexecutor/config_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import transformers
2+
3+
14
def is_nemotron_hybrid(config):
25
if hasattr(config, "hybrid_override_pattern"
36
) and config.hybrid_override_pattern is not None and len(
@@ -18,3 +21,32 @@ def is_qwen3_next(config):
1821
config, 'architectures'
1922
) and config.architectures is not None and config.architectures[
2023
0] == 'Qwen3NextForCausalLM'
24+
25+
26+
# TODO: remove this once transformers supports every model listed in _CONFIG_REGISTRY.
class LazyConfigDict(dict):
    """Dict of config-class names that are resolved to classes on subscript.

    The stored values are attribute names of ``tensorrt_llm._torch.configs``.
    ``d[key]`` returns the attribute with that name instead of the stored
    string. Only ``__getitem__`` is overridden, so ``get()``, ``values()`` and
    ``items()`` still yield the raw stored names.
    """

    def __getitem__(self, key):
        # Imported at lookup time rather than at module import time
        # (presumably to avoid an import cycle -- TODO confirm).
        import tensorrt_llm._torch.configs as configs

        class_name = super().__getitem__(key)
        return getattr(configs, class_name)
32+
33+
34+
# NOTE: HF config.json uses "deepseek_v32" as model_type, but it is loaded
# with the same DSV3 config class.
_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict({
    "deepseek_v32": "DeepseekV3Config",
})
37+
38+
39+
def load_pretrained_config(model_name_or_path: str,
                           trust_remote_code: bool = False,
                           **kwargs) -> transformers.PretrainedConfig:
    """Load the pretrained config for a checkpoint, preferring registry overrides.

    The raw config dict is read first to obtain its ``model_type``. When that
    type has an entry in ``_CONFIG_REGISTRY`` the registered config class loads
    the checkpoint; otherwise loading is delegated to
    ``transformers.AutoConfig``.

    Args:
        model_name_or_path: Checkpoint identifier handed to transformers.
        trust_remote_code: Passed to ``AutoConfig.from_pretrained`` only; the
            registry path does not receive it.
        **kwargs: Passed to ``PretrainedConfig.get_config_dict`` and to the
            registered config class's ``from_pretrained``; not passed to
            ``AutoConfig.from_pretrained``.

    Returns:
        The loaded config object.
    """
    raw_config, _ = transformers.PretrainedConfig.get_config_dict(
        model_name_or_path, **kwargs)
    model_type = raw_config.get("model_type")

    # Anything the registry does not know about goes through AutoConfig.
    if model_type not in _CONFIG_REGISTRY:
        return transformers.AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code)

    registered_cls = _CONFIG_REGISTRY[model_type]
    return registered_cls.from_pretrained(model_name_or_path, **kwargs)

tensorrt_llm/bench/build/build.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
2-
from transformers import AutoConfig
32

43
from pathlib import Path
54
from typing import Tuple, get_args
65
import click
76
from click_option_group import AllOptionGroup, optgroup
87

9-
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
8+
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid, load_pretrained_config
109
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
1110
from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
1211
from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -86,9 +85,9 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
8685
Raises:
8786
ValueError: When model is not supported.
8887
"""
89-
if is_nemotron_hybrid(
90-
AutoConfig.from_pretrained(model_path or model_name,
91-
trust_remote_code=True)):
88+
pretrained_config = load_pretrained_config(model_path or model_name,
89+
trust_remote_code=True)
90+
if is_nemotron_hybrid(pretrained_config):
9291
return NemotronHybridConfig.from_hf(model_name, model_path)
9392
return ModelConfig.from_hf(model_name, model_path)
9493

tensorrt_llm/bench/build/dataclasses.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from transformers import AutoConfig
21
from typing import Optional, Literal
32
from pydantic import AliasPath, BaseModel, Field, AliasChoices, model_validator
43
import huggingface_hub
@@ -14,6 +13,8 @@
1413
import json
1514
import struct
1615

16+
from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config
17+
1718

1819
def parse_safetensors_file_metadata(model_path, filename):
1920

@@ -192,9 +193,10 @@ def get_param_count(cls, model_hf_name, hf_model_path):
192193

193194
@classmethod
194195
def from_hf(cls, model_hf_name, hf_model_path):
195-
model_name_or_path = hf_model_path or model_hf_name
196-
hf_config = AutoConfig.from_pretrained(
197-
model_name_or_path, trust_remote_code=True).to_dict()
196+
pretrained_config = load_pretrained_config(hf_model_path
197+
or model_hf_name,
198+
trust_remote_code=True)
199+
hf_config = pretrained_config.to_dict()
198200
param_count = cls.get_param_count(model_hf_name, hf_model_path)
199201

200202
return cls(name=model_hf_name, param_count=param_count, **hf_config)

tensorrt_llm/serve/openai_server.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from fastapi.exceptions import RequestValidationError
1818
from fastapi.responses import JSONResponse, Response, StreamingResponse
1919
from starlette.routing import Mount
20-
from transformers import AutoConfig, AutoProcessor
20+
from transformers import AutoProcessor
2121

2222
from tensorrt_llm._tensorrt_engine import LLM
2323
# yapf: disable
@@ -101,27 +101,15 @@ def __init__(self,
101101
except Exception:
102102
logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path)
103103
self.processor = None
104-
# Temporary workaround for DSv3.2 config.
105-
import transformers
106-
107-
from tensorrt_llm._torch.model_config import _CONFIG_REGISTRY
108-
config_dict, _ = transformers.PretrainedConfig.get_config_dict(
109-
hf_tokenizer_path,
110-
trust_remote_code=trust_remote_code
111-
)
112-
model_type = config_dict.get("model_type")
113-
if model_type in _CONFIG_REGISTRY:
114-
config_class = _CONFIG_REGISTRY[model_type]
115-
self.model_config = config_class.from_pretrained(
116-
hf_tokenizer_path,
117-
trust_remote_code=trust_remote_code
118-
)
119-
else:
120-
try:
121-
self.model_config = AutoConfig.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code)
122-
except Exception:
123-
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
124-
self.model_config = None
104+
# load model config
105+
try:
106+
from tensorrt_llm._torch.pyexecutor.config_utils import \
107+
load_pretrained_config
108+
self.model_config = load_pretrained_config(hf_tokenizer_path,
109+
trust_remote_code=trust_remote_code)
110+
except Exception:
111+
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
112+
self.model_config = None
125113

126114
# Enable response storage for Responses API
127115
self.enable_store = True

0 commit comments

Comments
 (0)