
Commit b275635

[https://nvbugs/5498478][fix] Fix eagle3 fp8 kv target model + bf16 draft model + chunked prefill (#8910)
Signed-off-by: Dylan Chen <[email protected]>
1 parent: c73efe1

File tree: 3 files changed (+130, -6 lines)


cpp/kernels/fmha_v2/setup.py

Lines changed: 10 additions & 6 deletions
@@ -3063,7 +3063,9 @@ def get_kernel_traits_code(specs_names):
 # 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed).
 # You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins.
 # This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
-def use_cubin_header(sm, head_size, dtype):
+def use_cubin_header(sm, head_size, dtype, output_dtype=None):
+    if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
+        return False
     return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
 
 
@@ -3074,7 +3076,7 @@ def get_cubin_header(kernel_traits, specs_names):
     cubin_lens_dict = {}
     for kspec, fname, lname, kname in specs_names:
         if generate_cu_trtllm and not use_cubin_header(
-                kspec.sm, kspec.head_size, kspec.dtype):
+                kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype):
             continue
         name = fname.replace('.', '_')
         data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
@@ -3229,7 +3231,8 @@ def get_cubin_header(kernel_traits, specs_names):
     if generate_cu_trtllm:
 
         def get_lname_from_kname(kname: str) -> str:
-            if use_cubin_header(int(sm), int(head_size), prec.lower()):
+            if use_cubin_header(int(sm), int(head_size), prec.lower(),
+                                output_prec.lower()):
                 return 'nullptr'
             lname = kname.replace('_kernel', '')
             mask_types = [
@@ -3248,8 +3251,9 @@ def get_lname_from_kname(kname: str) -> str:
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
 {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
-'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
-                                           prec.lower()) else '''\
+'''.format(**locals()) if use_cubin_header(int(sm),
+                                           int(head_size), prec.lower(),
+                                           output_prec.lower()) else '''\
 {{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
 {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
 0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
@@ -3791,7 +3795,7 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
             continue
         # for normal attention, we do not need return softmax for ws fp8 kernels currently.
         # also fp8 input and bf16 output is only needed for MLA kernel.
-        skip_combination = return_softmax or (output_dtype is not None)
+        skip_combination = return_softmax
         # for context mla, we need separate qkv as input layout when returning softmax.
         skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
         if not skip_combination:
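For readers who want the effect of the setup.py change at a glance, here is a minimal standalone sketch (not part of the commit) of the new use_cubin_header decision; the assertions at the end are illustrative examples, not tests from the repository:

    # Mirror of the updated helper in cpp/kernels/fmha_v2/setup.py.
    def use_cubin_header(sm, head_size, dtype, output_dtype=None):
        # New rule: e4m3 kernels that produce fp16/bf16 outputs are always
        # rebuilt from source, since no precompiled cubins cover that case.
        if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
            return False
        # Pre-existing rule: sm90 with head_size 128, or sm89 with e4m3.
        return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)

    # Illustrative checks: the fp8-in/bf16-out combination now skips cubins.
    assert use_cubin_header(89, 64, 'e4m3') is True
    assert use_cubin_header(89, 64, 'e4m3', 'bf16') is False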

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 8 additions & 0 deletions
@@ -124,6 +124,14 @@ bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forCo
     bool hasPerfGain = mayHavePerfGain(xqaParams);
     if (!hasPerfGain)
     {
+        if (!xqaParams.is_fp8_output && xqaParams.kv_cache_data_type == DATA_TYPE_E4M3
+            && (xqaParams.data_type == DATA_TYPE_BF16 || xqaParams.data_type == DATA_TYPE_FP16))
+        {
+            TLLM_LOG_DEBUG(
+                "JIT XQA is selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA "
+                "does not support this combination.");
+            return true;
+        }
         TLLM_LOG_DEBUG("JIT XQA is not used: maybe no performance gain");
         return false;
     }
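The C++ check above can be summarized as: when the kv cache is e4m3 but the model runs in fp16/bf16 and does not produce fp8 output, the JIT XQA kernel is chosen even without an expected performance gain, because MMHA has no kernel for that combination. A rough Python paraphrase of the predicate follows; XqaParamsSketch and the string dtypes are stand-ins for the real C++ XQAParams struct and dtype enums, used here only for illustration:

    from dataclasses import dataclass

    @dataclass
    class XqaParamsSketch:
        data_type: str           # activation dtype, e.g. 'fp16', 'bf16', 'e4m3'
        kv_cache_data_type: str  # kv-cache dtype
        is_fp8_output: bool

    def must_use_jit_xqa(p: XqaParamsSketch) -> bool:
        # fp16/bf16 activations + e4m3 kv cache + non-fp8 output is the
        # combination MMHA cannot handle, so JIT XQA is forced.
        return (not p.is_fp8_output
                and p.kv_cache_data_type == 'e4m3'
                and p.data_type in ('fp16', 'bf16'))

    # The case this commit fixes: a bf16 eagle3 draft model attending over an
    # fp8 (e4m3) kv cache of the target model.
    assert must_use_jit_xqa(XqaParamsSketch('bf16', 'e4m3', False))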

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 112 additions & 0 deletions
@@ -520,5 +520,117 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
     llm_spec.shutdown()
 
 
+@pytest.mark.parametrize(
+    "enable_block_reuse,use_one_model,enable_chunked_prefill,fp8_target", [
+        [True, True, True, True],
+    ])
+@pytest.mark.high_cuda_memory
+def test_qwen3_eagle3(enable_block_reuse: bool, use_one_model: bool,
+                      enable_chunked_prefill: bool, fp8_target: bool):
+    # Eagle3 one model works with overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    use_chain_drafter = True
+    multi_batch = False
+    attention_dp = False
+
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/Zhi-Create-Qwen3-32B-Eagle3"
+    target_model_dir = f"{models_path}/Qwen3/Qwen3-32B"
+    if fp8_target:
+        target_model_dir = f"{models_path}/Qwen3/Qwen3-32B-FP8/"
+
+    # bs > 1 gives non-deterministic results when doing IFB. There is a slight
+    # chance that ref and spec do not match 100%.
+    max_batch_size = 4 if multi_batch else 1
+    max_draft_len = 3
+    kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                    max_tokens=8192)
+    if fp8_target:
+        kv_cache_config.dtype = 'fp8'
+    cuda_graph_config = CudaGraphConfig(
+        batch_sizes=[i for i in range(1, max_batch_size +
+                                      1)]) if use_cuda_graph else None
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend=attn_backend,
+        disable_overlap_scheduler=disable_overlap_scheduler,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        enable_attention_dp=attention_dp,
+        max_seq_len=8192,
+        enable_chunked_prefill=enable_chunked_prefill,
+    )
+    if enable_chunked_prefill:
+        # Use a small max_num_tokens so that the chunked prefill path gets exercised.
+        llm_common_config['max_num_tokens'] = 64
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        eagle3_one_model=use_one_model,
+    )
+    spec_config._allow_chain_drafter = use_chain_drafter
+
+    # Create the LLM instance
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+    # Acceptance rate tests
+    if enable_chunked_prefill:
+        # Use a long prompt for chunked prefill tests.
+        prompts = [
+            "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
+        ]
+        tok_ids = [llm_spec.tokenizer.encode(prompts[0])]
+    else:
+        prompts = [
+            "The capital of France is",
+            "The president of the United States is",
+        ]
+        tok_ids = [llm_spec.tokenizer.encode("The future of AI is")]
+        if multi_batch:
+            tok_ids.append(llm_spec.tokenizer.encode(prompts))
+
+    sampling_params = SamplingParams(max_tokens=128, temperature=0)
+    for i in range(len(tok_ids)):
+        num_tokens = 0
+        num_drafted = 0
+        num_accepted = 0
+
+        for output in llm_spec.generate_async(tok_ids[i],
+                                              sampling_params,
+                                              streaming=True):
+            new_tokens = output.outputs[0].token_ids
+            num_drafted += max_draft_len
+            num_accepted += len(new_tokens) - num_tokens - 1
+            num_tokens = len(new_tokens)
+
+        accept_rate = num_accepted / num_drafted
+        assert accept_rate > 0.10
+
+    # Output tests
+    sampling_params = SamplingParams(max_tokens=10, temperature=0)
+
+    results_spec = llm_spec.generate(prompts, sampling_params)
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
+
+    llm_ref = LLM(**llm_common_config)
+    results_ref = llm_ref.generate(prompts, sampling_params)
+    generated_text_ref = [result.outputs[0].text for result in results_ref]
+    llm_ref.shutdown()
+
+    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
+        # The spec decode algorithm currently guarantees identical results
+        assert text_spec == text_ref
+
+
 if __name__ == "__main__":
     unittest.main()
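A note on the acceptance-rate bookkeeping in the new test's streaming loop: each decoding step yields exactly one token from the target model plus however many of the max_draft_len draft tokens were accepted, so the accepted count per step is the number of newly visible tokens minus one. A small worked example with made-up token counts (not output from a real run):

    # Hypothetical streaming trace: total visible tokens after each step.
    # Step 1 adds 3 tokens (2 accepted drafts + 1 target token);
    # step 2 adds 1 token (0 accepted drafts). max_draft_len is 3 as in the test.
    totals_per_step = [3, 4]
    max_draft_len = 3

    num_tokens = num_drafted = num_accepted = 0
    for total in totals_per_step:
        num_drafted += max_draft_len            # drafts proposed this step
        num_accepted += total - num_tokens - 1  # new tokens minus the target token
        num_tokens = total

    assert (num_drafted, num_accepted) == (6, 2)
    print(num_accepted / num_drafted)  # 0.33..., comfortably above the 0.10 threshold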
