@@ -12,7 +12,8 @@
 from transformers.utils import HF_MODULES_CACHE

 from tensorrt_llm import logger
-from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
+from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
+                                                         load_pretrained_config)
 from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
@@ -25,18 +26,6 @@
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)


-class LazyConfigDict(dict):
-
-    def __getitem__(self, key):
-        import tensorrt_llm._torch.configs as configs
-        return getattr(configs, super().__getitem__(key))
-
-
-_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
-    deepseek_v32="DeepseekV3Config",
-)  # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class
-
-
 @dataclass
 class MoeLoadBalancerConfig:
     num_slots: Optional[int] = None
@@ -432,51 +421,31 @@ def from_pretrained(cls, |
         # When handling the case where model_format is TLLM_ENGINE
         # send cyclic requests to the NONE URL.
         if checkpoint_dir is not None:
-            config_dict, _ = transformers.PretrainedConfig.get_config_dict(
+            pretrained_config = load_pretrained_config(
                 checkpoint_dir,
+                trust_remote_code=trust_remote_code,
                 **kwargs,
             )
-            model_type = config_dict.get("model_type")
-            if model_type in _CONFIG_REGISTRY:
-                config_class = _CONFIG_REGISTRY[model_type]
-                pretrained_config = config_class.from_pretrained(
-                    checkpoint_dir,
-                    **kwargs,
-                )
-                if model_type == "deepseek_v32":
-                    sparse_attention_config = kwargs.get(
-                        'sparse_attention_config')
-                    kwargs[
-                        'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
-                            index_n_heads=(
-                                sparse_attention_config.index_n_heads
-                                if sparse_attention_config
-                                and sparse_attention_config.index_n_heads
-                                is not None else
-                                pretrained_config.index_n_heads),
-                            index_head_dim=(
-                                sparse_attention_config.index_head_dim
-                                if sparse_attention_config
-                                and sparse_attention_config.index_head_dim
-                                is not None else
-                                pretrained_config.index_head_dim),
-                            index_topk=(sparse_attention_config.index_topk
-                                        if sparse_attention_config and
-                                        sparse_attention_config.index_topk
-                                        is not None else
-                                        pretrained_config.index_topk),
-                            indexer_max_chunk_size=(
-                                sparse_attention_config.
-                                indexer_max_chunk_size
-                                if sparse_attention_config
-                                and sparse_attention_config.
-                                indexer_max_chunk_size is not None else
-                                None))
-            else:
-                pretrained_config = transformers.AutoConfig.from_pretrained(
-                    checkpoint_dir,
-                    trust_remote_code=trust_remote_code,
-                )
+            if pretrained_config.architectures[
+                    0] == "DeepseekV32ForCausalLM":
+                sparse_attention_config = kwargs.get(
+                    'sparse_attention_config')
+                if sparse_attention_config:
+                    index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
+                    index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
+                    index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
+                    indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
+                else:
+                    index_n_heads = pretrained_config.index_n_heads
+                    index_head_dim = pretrained_config.index_head_dim
+                    index_topk = pretrained_config.index_topk
+                    indexer_max_chunk_size = None
+                kwargs[
+                    'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
+                        index_n_heads=index_n_heads,
+                        index_head_dim=index_head_dim,
+                        index_topk=index_topk,
+                        indexer_max_chunk_size=indexer_max_chunk_size)
         else:
             raise ValueError(
                 "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."