201 changes: 164 additions & 37 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -57,7 +57,7 @@
from .cuda_graph_runner import CUDAGraphRunner
from .guided_decoder import CapturableGuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
from .llm_request import LlmRequest, get_draft_token_length
from .model_loader import ModelLoader
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
@@ -72,14 +72,13 @@ def get_max_num_sequences(self) -> int:
raise NotImplementedError

@abstractmethod
def forward(
self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[SampleStateTensors],
gather_context_logits: bool = False,
cache_indirection_buffer: Optional[torch.Tensor] = None,
):
def forward(self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[SampleStateTensors],
gather_context_logits: bool = False,
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None):
raise NotImplementedError

def warmup(self, resource_manager: ResourceManager) -> None:
@@ -492,6 +491,7 @@ def _run_torch_compile_warmup(self, resource_manager: ResourceManager):
return

logger.info("Running torch.compile warmup...")

kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
curr_max_num_tokens = min(
@@ -538,13 +538,21 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
self.batch_size * (self.max_seq_len - 1))

cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
#TODO fxiong: Called here?
with self.no_cuda_graph(), autotune(cache_path=cache_path,
rank=self.mapping.rank):
warmup_request = self._create_warmup_request(
resource_manager, curr_max_num_tokens, 0)
with self._release_batch_context(warmup_request,
resource_manager) as batch:
if batch is not None:
# Reset the flag is_first_draft for the draft model.
# This is necessary for the overlap scheduler.
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if self.is_draft_model and isinstance(
spec_resource_manager, Eagle3ResourceManager):
spec_resource_manager.is_first_draft = True
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
@@ -633,6 +641,7 @@ def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
return

logger.info("Running piecewise CUDA graph warmup...")

piecewise_cuda_graph_num_tokens = sorted(
self._piecewise_cuda_graph_num_tokens, reverse=True)

@@ -1182,10 +1191,13 @@ def _prepare_tp_inputs(
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
cache_indirection_buffer: Optional[torch.Tensor] = None):
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
"""
Prepare inputs for the PyTorch model.
"""

new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None
if new_tensors_device is not None:
# speculative decoding cases: [batch, 1 + draft_len], others: [batch]
@@ -1216,6 +1228,20 @@
mrope_position_ids = []
num_accepted_draft_tokens = [] # per request

# Track first_draft_requests info for GPU computation (draft model only)
first_draft_base_gather_ids = [
] # Base values for gather_ids computation
first_draft_seq_slots = [
] # seq_slots to index into num_accepted_tokens_device
first_draft_request_indices = [
] # Indices in the num_accepted_draft_tokens list

# Track positions in input_ids for GPU updates (draft model only)
context_input_ids_positions = [
] # (start_idx, end_idx, seq_slot) for context requests
first_draft_input_ids_positions = [
] # (start_idx, end_idx, seq_slot) for first_draft requests

for request in scheduled_requests.context_requests:
request_ids.append(request.py_request_id)
all_prompt_tokens = request.get_tokens(0)
@@ -1225,7 +1251,20 @@
prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
position_ids.extend(
range(begin_compute, begin_compute + len(prompt_tokens)))
input_ids.extend(prompt_tokens)

# Track position for GPU update (draft model only)
if self.is_draft_model and num_accepted_tokens_device is not None:
start_idx = len(input_ids)
input_ids.extend(prompt_tokens)
end_idx = len(input_ids)
slot_idx = req_id_to_old_request[
request.py_request_id].py_seq_slot
context_input_ids_positions.append(
(start_idx, end_idx - 1,
slot_idx)) # end_idx-1 is the last token position
else:
input_ids.extend(prompt_tokens)

gather_ids.append(len(input_ids) - 1)
sequence_lengths.append(len(prompt_tokens))
num_accepted_draft_tokens.append(len(prompt_tokens) - 1)
@@ -1369,14 +1408,9 @@ def _prepare_tp_inputs(
previous_batch_indices.append(previous_batch_idx)
previous_pos_indices.extend([previous_batch_idx] *
(1 + self.runtime_draft_len))
if self.spec_config.spec_dec_mode.has_draft_model():
# In the overlap scheduler workflow, if having draft model, we already updated the previous batch before launching the target model,
# so we only need to add the runtime_draft_len to the past_seen_token_num.
num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len)
else:
num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len + 1)

num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len + 1)
request.cached_tokens = num_cached_tokens_per_seq[-1]
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
self.attn_backend):
@@ -1387,20 +1421,47 @@
for request in first_draft_requests:
request_ids.append(request.py_request_id)
all_prompt_tokens = request.get_tokens(0)

draft_lens.append(0)
begin_compute = len(
all_prompt_tokens) - self.original_max_draft_len - 1
end_compute = begin_compute + self.original_max_draft_len + 1
prompt_tokens = all_prompt_tokens[begin_compute:end_compute]
position_ids.extend(
range(begin_compute, begin_compute + len(prompt_tokens)))
input_ids.extend(prompt_tokens)
gather_ids.append(
len(input_ids) - 1 - (self.original_max_draft_len -
request.py_num_accepted_draft_tokens))

# Track position for GPU update (draft model only)
if self.is_draft_model and num_accepted_tokens_device is not None:
start_idx = len(input_ids)
input_ids.extend(prompt_tokens)
end_idx = len(input_ids)
# For first_draft, we need to replace the last original_max_draft_len+1 tokens
slot_idx = req_id_to_old_request[
request.py_request_id].py_seq_slot
first_draft_input_ids_positions.append(
(start_idx, end_idx, slot_idx))

# Store info for GPU computation of gather_ids and num_accepted_draft_tokens
base_gather_id = len(
input_ids) - 1 - self.original_max_draft_len
gather_ids.append(
base_gather_id) # Placeholder, will be corrected on GPU
first_draft_base_gather_ids.append(base_gather_id)
first_draft_seq_slots.append(slot_idx)
first_draft_request_indices.append(
len(num_accepted_draft_tokens))

num_accepted_draft_tokens.append(
0) # Placeholder, will be corrected on GPU
else:
input_ids.extend(prompt_tokens)
gather_ids.append(
len(input_ids) - 1 - (self.original_max_draft_len -
request.py_num_accepted_draft_tokens))
num_accepted_draft_tokens.append(
request.py_num_accepted_draft_tokens)

sequence_lengths.append(1 + self.original_max_draft_len)
num_accepted_draft_tokens.append(
request.py_num_accepted_draft_tokens)
prompt_lengths.append(request.py_prompt_len)
past_seen_token_num = begin_compute
num_cached_tokens_per_seq.append(past_seen_token_num)
@@ -1420,7 +1481,17 @@
# skip adding input_ids of CUDA graph dummy requests so that new_tokens_device
# can be aligned to the correct positions.
if not request.is_cuda_graph_dummy:
input_ids.append(request.get_last_tokens(beam))
# Track position for GPU update (draft model only)
if self.is_draft_model and num_accepted_tokens_device is not None:
start_idx = len(input_ids)
input_ids.append(request.get_last_tokens(beam))
end_idx = len(input_ids)
slot_idx = req_id_to_old_request[
request.py_request_id].py_seq_slot
first_draft_input_ids_positions.append(
(start_idx, end_idx, slot_idx))
else:
input_ids.append(request.get_last_tokens(beam))
past_seen_token_num = request.max_beam_num_tokens - 1
else:
# the request has previous tensor
@@ -1495,6 +1566,30 @@ def previous_seq_slots_device():
dtype=torch.int,
pin_memory=True)
self.input_ids_cuda[:num_tokens].copy_(input_ids, non_blocking=True)

# Update input_ids_cuda with new tokens from new_tensors_device (draft model only)
if self.is_draft_model and num_accepted_tokens_device is not None:
# For context requests: replace the last token with new_tensors_device[0, seq_slot, 0]
if len(context_input_ids_positions) > 0:
for start_idx, last_token_idx, seq_slot in context_input_ids_positions:
# Update the last token position with the new token
self.input_ids_cuda[
last_token_idx] = new_tensors_device.new_tokens[
0, seq_slot, 0]

# For first_draft requests: replace the last (original_max_draft_len+1) tokens
# with new_tensors_device[:, seq_slot, 0]
if len(first_draft_input_ids_positions) > 0:
for start_idx, end_idx, seq_slot in first_draft_input_ids_positions:
# The range is [start_idx, end_idx), which is original_max_draft_len+1 tokens
num_tokens_to_replace = end_idx - start_idx
self.input_ids_cuda[
start_idx:
end_idx] = new_tensors_device.new_tokens[:
num_tokens_to_replace,
seq_slot,
0]

if num_draft_tokens > 0:
draft_tokens = torch.tensor(draft_tokens,
dtype=torch.int,
@@ -1508,6 +1603,19 @@ def previous_seq_slots_device():
self.num_accepted_draft_tokens_cuda[:len(
num_accepted_draft_tokens)].copy_(num_accepted_draft_tokens,
non_blocking=True)

# Update num_accepted_draft_tokens_cuda for first_draft_requests directly from num_accepted_tokens_device (draft model only)
if self.is_draft_model and len(first_draft_seq_slots) > 0:
first_draft_seq_slots_tensor = torch.tensor(
first_draft_seq_slots, dtype=torch.int, device='cuda')
first_draft_indices_tensor = torch.tensor(
first_draft_request_indices, dtype=torch.int, device='cuda')
# Extract accepted tokens for first_draft requests from device tensor
accepted_tokens = num_accepted_tokens_device[
first_draft_seq_slots_tensor]
# Update the correct positions in num_accepted_draft_tokens_cuda
self.num_accepted_draft_tokens_cuda[
first_draft_indices_tensor] = accepted_tokens
if next_draft_tokens_device is not None:
# Initialize these two values to zeros
self.previous_pos_id_offsets_cuda *= 0
@@ -1613,6 +1721,20 @@ def previous_seq_slots_device():
gather_ids, dtype=torch.int, pin_memory=True),
non_blocking=True)

# Update gather_ids for first_draft_requests on GPU (draft model only)
if self.is_draft_model and len(first_draft_seq_slots) > 0:
first_draft_seq_slots_tensor = torch.tensor(
first_draft_seq_slots, dtype=torch.int, device='cuda')
first_draft_indices_tensor = torch.tensor(
first_draft_request_indices, dtype=torch.int, device='cuda')
# Extract accepted tokens for first_draft requests from device tensor
accepted_tokens = num_accepted_tokens_device[
first_draft_seq_slots_tensor]
# Update gather_ids: gather_id = base_gather_id + num_accepted_tokens
# (since gather_id = len(input_ids) - 1 - (max_draft_len - num_accepted))
self.gather_ids_cuda[
first_draft_indices_tensor] += accepted_tokens

if not attn_metadata.is_cuda_graph:
# Assumes seq lens do not change between CUDA graph invocations. This applies
# to draft sequences too. This means that all draft sequences must be padded.
@@ -2211,7 +2333,9 @@ def _prepare_inputs(
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
cache_indirection_buffer: Optional[torch.Tensor] = None):
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
@@ -2227,18 +2351,20 @@
return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager,
attn_metadata, spec_metadata,
new_tensors_device,
cache_indirection_buffer)
cache_indirection_buffer,
num_accepted_tokens_device,
req_id_to_old_request)

@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
def forward(
self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[SampleStateTensors] = None,
gather_context_logits: bool = False,
cache_indirection_buffer: Optional[torch.Tensor] = None,
):
def forward(self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[SampleStateTensors] = None,
gather_context_logits: bool = False,
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)

@@ -2293,7 +2419,8 @@ def forward(

inputs, gather_ids = self._prepare_inputs(
padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
new_tensors_device, cache_indirection_buffer)
new_tensors_device, cache_indirection_buffer,
num_accepted_tokens_device, req_id_to_old_request)

self.iter_counter += 1
with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
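Note (editorial, not part of the diff): the first-draft bookkeeping above follows one pattern in two places. The host writes placeholder values (gather_ids computed as if zero draft tokens were accepted, and zero accepted-token counts), and the GPU then corrects them in place from num_accepted_tokens_device indexed by sequence slot, so no device-to-host synchronization is needed before launching the draft model. Below is a minimal standalone sketch of that correction; the concrete tensor sizes and values are assumptions for illustration only.

# Illustrative sketch only -- not part of this change.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
original_max_draft_len = 3   # assumed value for the sketch

# Host-side placeholders for two first-draft requests, as written above:
#   base gather_id = len(input_ids) - 1 - original_max_draft_len
#   accepted-token count = 0
gather_ids_cuda = torch.tensor([3, 11], dtype=torch.int, device=device)
num_accepted_draft_tokens_cuda = torch.zeros(2, dtype=torch.int, device=device)

# Device-side result from the sampler: accepted tokens per sequence slot.
num_accepted_tokens_device = torch.tensor([2, 0, 1, 3],
                                          dtype=torch.int, device=device)

# Bookkeeping collected on the host while building the batch.
first_draft_seq_slots = torch.tensor([0, 2], dtype=torch.int, device=device)
first_draft_request_indices = torch.tensor([0, 1], dtype=torch.int, device=device)

# GPU-side correction, mirroring the change:
#   gather_id = len(input_ids) - 1 - (max_draft_len - num_accepted)
#             = base_gather_id + num_accepted
accepted = num_accepted_tokens_device[first_draft_seq_slots]
num_accepted_draft_tokens_cuda[first_draft_request_indices] = accepted
gather_ids_cuda[first_draft_request_indices] += accepted

print(gather_ids_cuda)                 # tensor([ 5, 12], ...)
print(num_accepted_draft_tokens_cuda)  # tensor([2, 1], ...)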