@@ -5,7 +5,6 @@
 
 import torch
 
-from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
@@ -117,17 +116,11 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
             # 1-model has separate logic for handling draft tokens
             return False
 
-        if issubclass(attention_backend,
-                      TrtllmAttention) and self.is_mtp_eagle():
-            # TRTLLM MLA does not work with the chunked context mode.
-            return False
-
-        return not issubclass(attention_backend,
-                              TrtllmAttention) or get_sm_version() != 100
+        return not issubclass(attention_backend, TrtllmAttention)
 
     def attention_need_spec_dec_mode(
             self,
-            spec_resource_manager: BaseResourceManager,
+            spec_resource_manager: Optional[BaseResourceManager],
            is_draft_model: bool,
             attention_backend: Type[AttentionBackend],
             use_chain_drafter: bool,
@@ -137,9 +130,10 @@ def attention_need_spec_dec_mode(
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
         """
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
-        return self.is_eagle3_one_model() or (
+
+        return self.is_eagle3_one_model() or not is_draft_model or (
             self.is_eagle3() and spec_resource_manager.is_first_draft
-            and is_trtllm_attention and use_chain_drafter and is_draft_model)
+            and is_trtllm_attention)
 
     @staticmethod
     def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
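For readers skimming the diff, a minimal, self-contained sketch of how the two predicates behave after this commit. The stub backend classes (VanillaAttention in particular) and the flattened free-function signatures are illustrative assumptions, not TensorRT-LLM's actual class layout; only the boolean logic mirrors the diff above.

from typing import Type


class AttentionBackend:
    """Stand-in for ..attention_backend.AttentionBackend (assumption)."""


class TrtllmAttention(AttentionBackend):
    """Stand-in for ..attention_backend.trtllm.TrtllmAttention (assumption)."""


class VanillaAttention(AttentionBackend):
    """Hypothetical non-TRTLLM backend, for contrast."""


def extend_ctx(attention_backend: Type[AttentionBackend]) -> bool:
    # Chunked-context extension is now off for TRTLLM attention on every
    # architecture: the get_sm_version() != 100 escape hatch and the
    # MTP-Eagle special case are both gone.
    return not issubclass(attention_backend, TrtllmAttention)


def attention_need_spec_dec_mode(is_eagle3_one_model: bool, is_eagle3: bool,
                                 is_first_draft: bool, is_draft_model: bool,
                                 attention_backend: Type[AttentionBackend],
                                 use_chain_drafter: bool) -> bool:
    # The target model (is_draft_model=False) now always runs the attention
    # kernel in spec-dec mode; a draft model only does so on the first draft
    # step of Eagle3 with TRTLLM attention. use_chain_drafter no longer
    # affects the result and is kept here only to mirror the signature.
    is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
    return is_eagle3_one_model or not is_draft_model or (
        is_eagle3 and is_first_draft and is_trtllm_attention)


# Chunked-context gate: disabled for TRTLLM attention, enabled otherwise.
assert extend_ctx(TrtllmAttention) is False
assert extend_ctx(VanillaAttention) is True

# Target model: always spec-dec mode now, regardless of backend.
assert attention_need_spec_dec_mode(False, True, False, False,
                                    VanillaAttention, True) is True
# Draft model: only on the first Eagle3 draft step with TRTLLM attention.
assert attention_need_spec_dec_mode(False, True, True, True,
                                    TrtllmAttention, False) is True
assert attention_need_spec_dec_mode(False, True, False, True,
                                    TrtllmAttention, False) is False

The new Optional[BaseResourceManager] annotation is consistent with the rewritten expression: when is_draft_model is False, the or chain short-circuits before spec_resource_manager.is_first_draft is evaluated, so a None manager is only ever dereferenced on the draft-model path.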