11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -6,6 +6,7 @@ transforms:
############################################################################################
build_model:
stage: factory
run_per_gm: false
device: meta
# nothing to clean up
run_graph_cleanup: false
@@ -14,8 +15,8 @@ transforms:
stage: export
clone_state_dict: false
strict: false
# nothing to clean up
run_graph_cleanup: false
run_per_gm: false
run_graph_cleanup: true
requires_clean_graph: false
cleanup_noop_slice:
stage: post_export
@@ -35,6 +36,7 @@ transforms:
run_shape_prop: true
match_eager_attention:
stage: pattern_matcher
requires_shape_prop: true
match_grouped_attention:
stage: pattern_matcher
match_attention_layout:
@@ -87,8 +89,10 @@ transforms:
############################################################################################
load_weights:
stage: weight_load
run_per_gm: false
move_inputs_to_device:
stage: weight_load
run_per_gm: false
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
@@ -138,10 +142,13 @@ transforms:
attn_backend: cuda_causal_conv
initialize_cache:
stage: cache_init
run_per_gm: false
resize_kv_cache:
stage: cache_init
run_per_gm: false
############################################################################################
# COMPILE MODEL
############################################################################################
compile_model:
stage: compile
run_per_gm: false
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/transformers.yaml
@@ -6,23 +6,29 @@ transforms:
############################################################################################
build_and_load_factory_model:
stage: factory
run_per_gm: false
############################################################################################
# MOVE ARGUMENTS TO DEVICE
############################################################################################
move_inputs_to_device:
stage: weight_load
run_per_gm: false
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
detect_hf_attn_layers:
stage: cache_init
run_per_gm: false
transformers_replace_cached_attn:
stage: cache_init
attn_backend: flashinfer
run_per_gm: false
initialize_cache:
stage: cache_init
run_per_gm: false
resize_kv_cache:
stage: cache_init
run_per_gm: false
############################################################################################
# COMPILE MODEL
############################################################################################
@@ -325,7 +325,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]:
like ``insert_cached_attention`` to extract the constant arguments and add them to the
``prepare_metadata`` node/op.
"""
return tuple(self.named_standard_args.keys())
# NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids
# is part of the graph, e.g., in situations where the graph is a submodule of the overall
# model. In such instances, the graph usually sees inputs_embeds. However, we assume for
# now that position_ids is always part of the graph.
return ("position_ids",) + self._cached_arg_names

@property
def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
@@ -466,7 +470,9 @@ def _get_cache_locations_and_pages_per_sequence(
return cache_loc_flat, pages_per_seq

@classmethod
def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
def _get_sanitized_seq_len(
cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> torch.Tensor:
"""Sanitize sequence lengths.

We want to cover the following scenarios with this function:
@@ -499,22 +505,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor)
# valid cache location in the batch. This would ensure that the dummy sequences just
# repeat valid computation...
"""
_, s = input_ids.shape[:2]
num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
_, s = input_or_position_ids.shape[:2]
num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
if s > 1:
return seq_len[:num_seq].detach().clone()
else:
return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)

@staticmethod
def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
def _get_sanitized_num_sequences(
input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> int:
"""Get number of sequences.

We make sure that this function is compatible with both torch graph capture and cudagraph.
Both can be a bit temperamental when trying to extract the number of sequences from a tensor
with max_batch_size or max_batch_size*max_seq_len.
"""
b, s = input_ids.shape[:2]
b, s = input_or_position_ids.shape[:2]
if s > 1:
num_seq = torch.sum(seq_len > 0)
assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
@@ -814,7 +822,6 @@ def __call__(
class PrepareMetadataCallable(Protocol):
def __call__(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -901,7 +908,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:

```
def prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
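The interface change above is the heart of the PR: sequence-length sanitization no longer receives `input_ids` and instead works off `position_ids`, which the NOTE argues is always a graph input even when a submodule only sees `inputs_embeds`. Below is a minimal sketch of that sanitization logic, assuming a zero-padded `seq_len` tensor and a `[batch, seq]`-shaped `position_ids` tensor; the real helpers keep everything as tensors for graph-capture friendliness, and the decode-phase fallback (`num_seq = b`) is an assumption here since that branch of `_get_sanitized_num_sequences` is truncated in the diff.

```python
import torch


def sanitized_seq_len(position_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    """Sketch of SequenceInfo._get_sanitized_seq_len after the switch to position_ids."""
    b, s = position_ids.shape[:2]
    if s > 1:
        # prefill/context phase: the active sequences are the non-zero entries of seq_len
        num_seq = int(torch.sum(seq_len > 0))
        assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
        return seq_len[:num_seq].detach().clone()
    # decode phase (s == 1): one token per sequence; taking num_seq = b is an
    # assumption -- the real helper has extra handling for graph capture.
    return torch.ones(b, dtype=seq_len.dtype, device=seq_len.device)


# Example: two active sequences (lengths 5 and 3) in a batch padded to 4 slots.
position_ids = torch.arange(5).repeat(4, 1)       # shape [4, 5]
seq_len = torch.tensor([5, 3, 0, 0])
print(sanitized_seq_len(position_ids, seq_len))   # tensor([5, 3])
```

Because `position_ids` has the same batch/sequence geometry as `input_ids`, the shape-based logic is unchanged; only the argument name and the guarantee about which tensors reach a submodule differ.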
@@ -54,7 +54,6 @@ def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int)
# ---------------------------------------------------------------
@torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=())
def cuda_causal_conv_prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -67,7 +66,7 @@ def cuda_causal_conv_prepare_metadata(

Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
"""
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)

seq_start = torch.zeros_like(seq_len_sanitized)
@@ -81,9 +80,9 @@

@cuda_causal_conv_prepare_metadata.register_fake
def cuda_causal_conv_prepare_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
return (
torch.empty_like(seq_len_sanitized),
@@ -155,7 +155,6 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):

@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
def prepare_flashinfer_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -174,7 +173,7 @@ def prepare_flashinfer_metadata(
_GlobalFlashInferPlanner.reset()

# retrieve sanitized metadata
seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len)

# prepare flashinfer-style metadata
@@ -214,9 +213,9 @@ def prepare_flashinfer_metadata(
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
return (
qo_indptr, # qo_indptr
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py
@@ -175,7 +175,6 @@ def fused_flattened_mla_with_cache_fake(
"auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=()
)
def prepare_fused_mla_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -184,7 +183,7 @@ def prepare_fused_mla_metadata(
slot_idx: torch.Tensor,
page_size: int,
) -> List[torch.Tensor]:
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
seq_start = torch.zeros_like(seq_len[:num_seq])
seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
return (
@@ -197,7 +196,7 @@ def prepare_fused_mla_metadata(

@prepare_fused_mla_metadata.register_fake
def prepare_fused_mla_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
return (
torch.empty_like(seq_len),
@@ -356,7 +356,6 @@ def torch_backend_mha_with_cache_fake(

@torch.library.custom_op("auto_deploy::torch_cached_attention_prepare_metadata", mutates_args=())
def torch_backend_prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -366,7 +365,7 @@ def torch_backend_prepare_metadata(
page_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for torch backend attention (similar to triton backend)."""
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
seq_start = torch.zeros_like(seq_len[:num_seq])
seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
return (
@@ -379,9 +378,9 @@ def torch_backend_prepare_metadata(

@torch_backend_prepare_metadata.register_fake
def torch_backend_prepare_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
return (
torch.empty_like(seq_len[:num_seq]),
torch.empty_like(input_pos[:num_seq]),
@@ -140,7 +140,6 @@ def _torch_causal_conv1d_decode(

@torch.library.custom_op("auto_deploy::torch_causal_conv_prepare_metadata", mutates_args=())
def torch_causal_conv_prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -153,7 +152,7 @@ def torch_causal_conv_prepare_metadata(

Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
"""
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)

seq_start = torch.zeros_like(seq_len_sanitized)
@@ -167,9 +166,9 @@

@torch_causal_conv_prepare_metadata.register_fake
def torch_causal_conv_prepare_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
return (
torch.empty_like(seq_len_sanitized),
@@ -113,7 +113,6 @@ def _update_ssm_state_cache(ssm_cache: torch.Tensor, ssm_state: torch.Tensor) ->

@torch.library.custom_op("auto_deploy::torch_ssm_prepare_metadata", mutates_args=())
def _torch_ssm_prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -127,7 +126,7 @@ def _torch_ssm_prepare_metadata(
Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
"""
# Determine number of active sequences and compute seq_start boundaries
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)

seq_start = torch.zeros_like(seq_len_sanitized)
@@ -142,10 +141,10 @@ def _torch_ssm_prepare_metadata(

@_torch_ssm_prepare_metadata.register_fake
def _torch_ssm_prepare_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
# Use the same sanitization logic to determine sizes in fake mode
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
return (
torch.empty_like(seq_len_sanitized),
@@ -284,7 +284,6 @@ def flattened_mha_fake(
"auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=()
)
def prepare_fused_mha_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -294,7 +293,7 @@ def prepare_fused_mha_metadata(
page_size: int,
) -> List[torch.Tensor]:
# TODO: maybe use slot_idx instead of pages_per_seq??
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
seq_start = torch.zeros_like(seq_len[:num_seq])
seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
return (
@@ -309,9 +308,9 @@ def prepare_fused_mha_metadata(
# SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_fused_mha_metadata.register_fake
def prepare_fused_mha_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
return (
torch.empty_like(seq_len[:num_seq]),
torch.empty_like(input_pos[:num_seq]),
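Every backend touched in this PR follows the same two-part registration pattern: a `torch.library.custom_op` that computes the real metadata from `position_ids` and `seq_len`, plus a `register_fake` twin that only has to return correctly shaped tensors so export and fake-tensor tracing work. Below is a stripped-down sketch of that pattern with a made-up op name (`example::prepare_metadata_demo`) and a simplified two-tensor return value; it assumes a PyTorch version (2.4+) where `torch.library.custom_op` is available and is an illustration, not the library's actual op.

```python
from typing import List

import torch


@torch.library.custom_op("example::prepare_metadata_demo", mutates_args=())
def prepare_metadata_demo(position_ids: torch.Tensor, seq_len: torch.Tensor) -> List[torch.Tensor]:
    # real implementation: derive the number of active sequences and their start offsets
    num_seq = int(torch.sum(seq_len > 0)) if position_ids.shape[1] > 1 else position_ids.shape[0]
    seq_start = torch.zeros_like(seq_len[:num_seq])
    seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
    return [seq_len[:num_seq].clone(), seq_start]


@prepare_metadata_demo.register_fake
def prepare_metadata_demo_fake(position_ids, seq_len):
    # fake mode only needs tensors of the right shape/dtype; no data-dependent logic
    return [torch.empty_like(seq_len), torch.empty_like(seq_len)]


# Prefill-style call: two sequences of lengths 4 and 2.
position_ids = torch.arange(4).repeat(2, 1)       # shape [2, 4]
seq_len = torch.tensor([4, 2])
print(torch.ops.example.prepare_metadata_demo(position_ids, seq_len))
```

The `mutates_args=()` declaration mirrors the ops in this diff: metadata preparation reads its inputs but does not write to any cache tensors, which is what lets the fake implementations stay shape-only.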