@@ -5,7 +5,6 @@
 
 import torch
 
-from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
@@ -117,13 +116,7 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
             # 1-model has separate logic for handling draft tokens
             return False
 
-        if issubclass(attention_backend,
-                      TrtllmAttention) and self.is_mtp_eagle():
-            # TRTLLM MLA does not work with the chunked context mode.
-            return False
-
-        return not issubclass(attention_backend,
-                              TrtllmAttention) or get_sm_version() != 100
+        return not issubclass(attention_backend, TrtllmAttention)
 
     def attention_need_spec_dec_mode(
         self,
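A minimal runnable sketch of the predicate change in this hunk. The classes and helpers below are stand-ins so the snippet runs without TensorRT-LLM installed; `FlashInferAttention` and the `extend_ctx_before`/`extend_ctx_after` free functions are hypothetical, not names from the codebase:

```python
# Stand-ins modeling AttentionBackend and TrtllmAttention from the imports above.
class AttentionBackend:
    pass


class TrtllmAttention(AttentionBackend):
    pass


class FlashInferAttention(AttentionBackend):  # hypothetical non-TRTLLM backend
    pass


def extend_ctx_before(attention_backend, is_mtp_eagle, sm_version):
    # Old behavior: TRTLLM attention could still extend the context on
    # non-SM100 parts, except in MTP-Eagle mode (TRTLLM MLA does not work
    # with the chunked context mode).
    if issubclass(attention_backend, TrtllmAttention) and is_mtp_eagle:
        return False
    return not issubclass(attention_backend,
                          TrtllmAttention) or sm_version != 100


def extend_ctx_after(attention_backend):
    # New behavior: disabled for every TrtllmAttention subclass, so the
    # SM-version and MTP-Eagle checks (and get_sm_version) are gone.
    return not issubclass(attention_backend, TrtllmAttention)


assert extend_ctx_before(TrtllmAttention, is_mtp_eagle=False, sm_version=90)
assert not extend_ctx_after(TrtllmAttention)  # the behavior change
assert extend_ctx_after(FlashInferAttention)  # non-TRTLLM backends unchanged
```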
@@ -137,9 +130,8 @@ def attention_need_spec_dec_mode(
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
         """
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
-        return self.is_eagle3_one_model() or (
-            self.is_eagle3() and spec_resource_manager.is_first_draft
-            and is_trtllm_attention and use_chain_drafter and is_draft_model)
+        return self.is_eagle3_one_model() or not is_draft_model or (
+            spec_resource_manager.is_first_draft and is_trtllm_attention)
 
     @staticmethod
     def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
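A similar sketch, under the same caveats (plain booleans stand in for `self` and the spec resource manager; the `before`/`after` helpers are hypothetical), of how this hunk widens `attention_need_spec_dec_mode`: the target model (`is_draft_model=False`) now always requests spec-dec mode, and the `is_eagle3()` / `use_chain_drafter` conditions are dropped from the draft-model path:

```python
def need_spec_dec_mode_before(is_eagle3_one_model, is_eagle3, is_first_draft,
                              is_trtllm_attention, use_chain_drafter,
                              is_draft_model):
    return is_eagle3_one_model or (is_eagle3 and is_first_draft
                                   and is_trtllm_attention
                                   and use_chain_drafter and is_draft_model)


def need_spec_dec_mode_after(is_eagle3_one_model, is_first_draft,
                             is_trtllm_attention, is_draft_model):
    return is_eagle3_one_model or not is_draft_model or (
        is_first_draft and is_trtllm_attention)


# Target-model forward pass with Eagle3 (2-model): previously plain mode,
# now spec-dec (multi-token query) mode.
assert not need_spec_dec_mode_before(False, True, True, True, True, False)
assert need_spec_dec_mode_after(False, True, True, False)
```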