diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 1684fe08c9a..74675b725c6 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -6,6 +6,7 @@ transforms: ############################################################################################ build_model: stage: factory + run_per_gm: false device: meta # nothing to clean up run_graph_cleanup: false @@ -14,8 +15,8 @@ transforms: stage: export clone_state_dict: false strict: false - # nothing to clean up - run_graph_cleanup: false + run_per_gm: false + run_graph_cleanup: true requires_clean_graph: false cleanup_noop_slice: stage: post_export @@ -35,6 +36,7 @@ transforms: run_shape_prop: true match_eager_attention: stage: pattern_matcher + requires_shape_prop: true match_grouped_attention: stage: pattern_matcher match_attention_layout: @@ -87,8 +89,10 @@ transforms: ############################################################################################ load_weights: stage: weight_load + run_per_gm: false move_inputs_to_device: stage: weight_load + run_per_gm: false ############################################################################################ # RUN POST-LOAD FUSION AND OPTIMIZATIONS ############################################################################################ @@ -138,10 +142,13 @@ transforms: attn_backend: cuda_causal_conv initialize_cache: stage: cache_init + run_per_gm: false resize_kv_cache: stage: cache_init + run_per_gm: false ############################################################################################ # COMPILE MODEL ############################################################################################ compile_model: stage: compile + run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml index 5b32f81672d..529a3bb5879 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml @@ -6,23 +6,29 @@ transforms: ############################################################################################ build_and_load_factory_model: stage: factory + run_per_gm: false ############################################################################################ # MOVE ARGUMENTS TO DEVICE ############################################################################################ move_inputs_to_device: stage: weight_load + run_per_gm: false ############################################################################################ # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES ############################################################################################ detect_hf_attn_layers: stage: cache_init + run_per_gm: false transformers_replace_cached_attn: stage: cache_init attn_backend: flashinfer + run_per_gm: false initialize_cache: stage: cache_init + run_per_gm: false resize_kv_cache: stage: cache_init + run_per_gm: false ############################################################################################ # COMPILE MODEL ############################################################################################ diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index d61bb8854e8..7c2d28fbce1 100644 --- 
a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -325,7 +325,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]: like ``insert_cached_attention`` to extract the constant arguments and add them to the ``prepare_metadata`` node/op. """ - return tuple(self.named_standard_args.keys()) + # NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids + # is part of the graph, e.g., in situations where the graph is a submodule of the overall + # model. In such instances, the graph usually sees inputs_embeds. However, we assume for + # now that position_ids is always part of the graph. + return ("position_ids",) + self._cached_arg_names @property def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]: @@ -466,7 +470,9 @@ def _get_cache_locations_and_pages_per_sequence( return cache_loc_flat, pages_per_seq @classmethod - def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: + def _get_sanitized_seq_len( + cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor + ) -> torch.Tensor: """Sanitize sequence lengths. We want to cover the following scenarios with this function: @@ -499,22 +505,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) # valid cache location in the batch. This would ensure that the dummy sequences just # repeats valid computation... """ - _, s = input_ids.shape[:2] - num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len) + _, s = input_or_position_ids.shape[:2] + num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len) if s > 1: return seq_len[:num_seq].detach().clone() else: return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device) @staticmethod - def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int: + def _get_sanitized_num_sequences( + input_or_position_ids: torch.Tensor, seq_len: torch.Tensor + ) -> int: """Get number of sequences. We makes sure that this function is compatible with both torch graph capture and cudagraph. Both can be a bit temparamental when trying to extract the number of sequences from a tensor with max_batch_size or max_batch_size*max_seq_len. 
""" - b, s = input_ids.shape[:2] + b, s = input_or_position_ids.shape[:2] if s > 1: num_seq = torch.sum(seq_len > 0) assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded" @@ -814,7 +822,6 @@ def __call__( class PrepareMetadataCallable(Protocol): def __call__( self, - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -901,7 +908,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: ``` def prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py index fbfc2fad614..b8e134be19f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py @@ -54,7 +54,6 @@ def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int) # --------------------------------------------------------------- @torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=()) def cuda_causal_conv_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -67,7 +66,7 @@ def cuda_causal_conv_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). """ - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) seq_start = torch.zeros_like(seq_len_sanitized) @@ -81,9 +80,9 @@ def cuda_causal_conv_prepare_metadata( @cuda_causal_conv_prepare_metadata.register_fake def cuda_causal_conv_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) return ( torch.empty_like(seq_len_sanitized), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 3200a21937d..63a8c7b1547 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -155,7 +155,6 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper): @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=()) def prepare_flashinfer_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -174,7 +173,7 @@ def prepare_flashinfer_metadata( _GlobalFlashInferPlanner.reset() # retrieve sanitzed metadata - seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len) # prepare flashinfer-style metadata @@ -214,9 +213,9 @@ def prepare_flashinfer_metadata( # As SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_flashinfer_metadata.register_fake def prepare_flashinfer_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 
page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device) return ( qo_indptr, # qo_indptr diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index 20100531973..ea68da9e508 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -175,7 +175,6 @@ def fused_flattened_mla_with_cache_fake( "auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=() ) def prepare_fused_mla_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -184,7 +183,7 @@ def prepare_fused_mla_metadata( slot_idx: torch.Tensor, page_size: int, ) -> List[torch.Tensor]: - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0) return ( @@ -197,7 +196,7 @@ def prepare_fused_mla_metadata( @prepare_fused_mla_metadata.register_fake def prepare_fused_mla_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): return ( torch.empty_like(seq_len), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 6eadb4b4466..df2d4b24c59 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -356,7 +356,6 @@ def torch_backend_mha_with_cache_fake( @torch.library.custom_op("auto_deploy::torch_cached_attention_prepare_metadata", mutates_args=()) def torch_backend_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -366,7 +365,7 @@ def torch_backend_prepare_metadata( page_size: int, ) -> List[torch.Tensor]: """Prepare metadata for torch backend attention (similar to triton backend).""" - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0) return ( @@ -379,9 +378,9 @@ def torch_backend_prepare_metadata( @torch_backend_prepare_metadata.register_fake def torch_backend_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) return ( torch.empty_like(seq_len[:num_seq]), torch.empty_like(input_pos[:num_seq]), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py index 522779fd183..6aaf5ecb405 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py +++ 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py @@ -140,7 +140,6 @@ def _torch_causal_conv1d_decode( @torch.library.custom_op("auto_deploy::torch_causal_conv_prepare_metadata", mutates_args=()) def torch_causal_conv_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -153,7 +152,7 @@ def torch_causal_conv_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). """ - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) seq_start = torch.zeros_like(seq_len_sanitized) @@ -167,9 +166,9 @@ def torch_causal_conv_prepare_metadata( @torch_causal_conv_prepare_metadata.register_fake def torch_causal_conv_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) return ( torch.empty_like(seq_len_sanitized), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py index 4ac148e815e..6bf7eb84d14 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py @@ -113,7 +113,6 @@ def _update_ssm_state_cache(ssm_cache: torch.Tensor, ssm_state: torch.Tensor) -> @torch.library.custom_op("auto_deploy::torch_ssm_prepare_metadata", mutates_args=()) def _torch_ssm_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -127,7 +126,7 @@ def _torch_ssm_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). 
""" # Determine number of active sequences and compute seq_start boundaries - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) seq_start = torch.zeros_like(seq_len_sanitized) @@ -142,10 +141,10 @@ def _torch_ssm_prepare_metadata( @_torch_ssm_prepare_metadata.register_fake def _torch_ssm_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): # Use the same sanitization logic to determine sizes in fake mode - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) return ( torch.empty_like(seq_len_sanitized), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 56aad993a3c..34e0c5a988d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -284,7 +284,6 @@ def flattened_mha_fake( "auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=() ) def prepare_fused_mha_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -294,7 +293,7 @@ def prepare_fused_mha_metadata( page_size: int, ) -> List[torch.Tensor]: # TODO: maybe use slot_idx instead of pages_per_seq?? - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0) return ( @@ -309,9 +308,9 @@ def prepare_fused_mha_metadata( # SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_fused_mha_metadata.register_fake def prepare_fused_mha_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) return ( torch.empty_like(seq_len[:num_seq]), torch.empty_like(input_pos[:num_seq]), diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index a0895b61d14..6915dac8540 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -4,7 +4,7 @@ """ import time -from abc import ABC, abstractmethod +from abc import ABC from contextlib import nullcontext from enum import Enum from functools import total_ordering, wraps @@ -82,7 +82,7 @@ class TransformConfig(BaseModel): ### OPTIONAL CONFIG ########################################################################### run_per_gm: bool = Field( description="Whether to run the transform per graph (sub)module or on whole module.", - default=False, + default=True, ) enabled: bool = Field( default=True, @@ -126,9 +126,11 @@ class TransformInfo(BaseModel): } skipped: bool = Field( + default=True, description="Whether the transform was skipped.", ) num_matches: int = Field( + default=0, description="Number of 
matches found.", ) is_clean: bool = Field( @@ -145,6 +147,32 @@ class TransformInfo(BaseModel): "tensors in the graph and it preserves the has_valid_shapes flag of the last transform.", ) + @classmethod + def from_last_info(cls, info: "TransformInfo") -> "TransformInfo": + """Create a new TransformInfo from the last transform info.""" + return cls( + is_clean=info.is_clean, + has_valid_shapes=info.has_valid_shapes, + ) + + def __or__(self, other: "TransformInfo") -> "TransformInfo": + """Merge two TransformInfo objects.""" + return TransformInfo( + skipped=self.skipped and other.skipped, # we only count skipped if both were skipped + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean or other.is_clean, + has_valid_shapes=self.has_valid_shapes or other.has_valid_shapes, + ) + + def __and__(self, other: "TransformInfo") -> "TransformInfo": + """Merge two TransformInfo objects.""" + return TransformInfo( + skipped=self.skipped and other.skipped, # we only count skipped if both were skipped + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean and other.is_clean, + has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes, + ) + TransformHistory = Dict[str, TransformInfo] @@ -248,7 +276,7 @@ def from_kwargs(cls, **kwargs) -> "BaseTransform": @final def __call__( self, - gm: nn.Module, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, @@ -256,13 +284,13 @@ def __call__( """Apply the transform to the graph. Args: - gm: The graph module to apply the transform to. + mod: The model to apply the transform to. cm: The cached sequence interface defining the sequence interface. factory: The model factory used to build the model. shared_config: Global info shared between multiple transforms. Returns: - GraphModule: The transformed graph module. + nn.Module: The transformed model. NOTE: The transform can/should modify the graph module in place if possible. 
Returning the graph is mostly to standardize the interface for transforms that cannot modify the graph @@ -276,13 +304,16 @@ def __call__( t_name = self.get_transform_key() # retrieve autodeploy metadata from the graphmodule - autodeploy_meta = self._get_autodeploy_meta(gm) + autodeploy_meta = self._get_autodeploy_meta(mod) # retrieve transform history and last transform info history: TransformHistory = autodeploy_meta.get(self._history_key, {}) h_keys = list(history.keys()) # preserves order of insertion/transform execution info_last = history[h_keys[-1]] if h_keys else TransformInfo(skipped=False, num_matches=0) + # initialize new info object + info = TransformInfo.from_last_info(info_last) + # show debug info for debug config ad_logger.debug(f"{t_name} config: {self.config}") @@ -294,42 +325,47 @@ def __call__( # run or skip the transform if self.config.enabled: - # run graph pre-cleanup + # run graph pre-cleanup and update info object elapsed_time_pre_cleanup = -time.time() - is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last) + info = info | self._run_cleanup( + mod, + self.config.requires_clean_graph, + self.config.requires_shape_prop, + info.is_clean, + info.has_valid_shapes, + ) elapsed_time_pre_cleanup += time.time() # run the transform in a error-handling wrapper if desired elapsed_time_apply = -time.time() if self.config.skip_on_error: try: - gm, info = self._apply_per_gm(gm, cm, factory, shared_config) + mod, info_apply = self._apply_per_gm_or_whole_model( + mod, cm, factory, shared_config + ) except Exception as e: error_msg = f"Transform {t_name} failed" ad_logger.warning(f"{error_msg}: {e}") - info = TransformInfo(skipped=True, num_matches=0) + info_apply = TransformInfo(skipped=True, num_matches=0) else: # handle this here normally to improve debugging and error message - gm, info = self._apply_per_gm(gm, cm, factory, shared_config) + mod, info_apply = self._apply_per_gm_or_whole_model(mod, cm, factory, shared_config) elapsed_time_apply += time.time() # we cannot say it's clean if the previous wasn't clean even if this one is # create new info object with updated cleanup status - info_dict = info.model_dump() - info_dict["is_clean"] &= is_clean_pre - info_dict["has_valid_shapes"] &= has_valid_shapes_pre - info = TransformInfo(**info_dict) + info = info & info_apply # run graph post-cleanup elapsed_time_post_cleanup = -time.time() - info = self._run_post_cleanup(gm, info) + info = info | self._run_cleanup( + mod, + self.config.run_graph_cleanup, + self.config.run_shape_prop, + info.is_clean, + info.has_valid_shapes, + ) elapsed_time_post_cleanup += time.time() - else: - # skip the transform and set info object using the last transform info - info_dict = info_last.model_dump() - info_dict["skipped"] = True - info_dict["num_matches"] = 0 - info = TransformInfo(**info_dict) elapsed_time_total += time.time() @@ -348,36 +384,37 @@ def __call__( f"post_cleanup={elapsed_time_post_cleanup:.3f}s", ] self._log_info(", ".join(log_msgs_timing)) - ad_logger.debug(f"Graph after {t_name}: {gm}") + ad_logger.debug(f"Model after {t_name}: {mod}") # update + store new meta data history[t_name] = info autodeploy_meta[self._history_key] = history - self._set_autodeploy_meta(gm, autodeploy_meta) + self._set_autodeploy_meta(mod, autodeploy_meta) # return the graph module - return gm + return mod @final - def _apply_per_gm( + def _apply_per_gm_or_whole_model( self, - gm: nn.Module, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: 
SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: if not self.config.run_per_gm: - return self._apply(gm, cm, factory, shared_config) + return self._apply_to_full_model(mod, cm, factory, shared_config) # just run it on first graph module we are encountering for now... - for k, graph_sub in named_graphmodules(gm): - graph_sub, info = self._apply(graph_sub, cm, factory, shared_config) + info = TransformInfo() + for k, graph_sub in named_graphmodules(mod): + graph_sub, info_apply = self._apply(graph_sub, cm, factory, shared_config) if k == "": - gm = graph_sub + mod = graph_sub else: - gm.set_submodule(k, graph_sub) - break - return gm, info + mod.set_submodule(k, graph_sub) + info = info & info_apply + return mod, info @final def _log_info(self, *args: any): @@ -385,88 +422,56 @@ def _log_info(self, *args: any): ad_logger.info(*args) @final - def _get_autodeploy_meta(self, gm: GraphModule) -> AutodeployMeta: + def _get_autodeploy_meta(self, mod: nn.Module) -> AutodeployMeta: """Get the autodeploy metadata from the graphmodule.""" - if not hasattr(gm, "meta"): - gm.meta = {} - return gm.meta.get(self._autodeploy_meta_key, {}) + if not hasattr(mod, "meta"): + mod.meta = {} + return mod.meta.get(self._autodeploy_meta_key, {}) @final - def _set_autodeploy_meta(self, gm: GraphModule, autodeploy_meta: AutodeployMeta) -> None: + def _set_autodeploy_meta(self, mod: nn.Module, autodeploy_meta: AutodeployMeta) -> None: """Set the autodeploy metadata in the graphmodule.""" - if not hasattr(gm, "meta"): - gm.meta = {} - gm.meta[self._autodeploy_meta_key] = autodeploy_meta + if not hasattr(mod, "meta"): + mod.meta = {} + mod.meta[self._autodeploy_meta_key] = autodeploy_meta @final - def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]: + def _run_cleanup( + self, + mod: nn.Module, + clean_graph: bool, + clean_shape: bool, + is_clean: bool, + has_valid_shapes: bool, + ) -> TransformInfo: """Run graph cleanup before the transform. Args: - gm: The graph module to run cleanup on. - info: The last transform info. + mod: The model to run cleanup on. + clean_graph: Whether we want a clean graph after the transform. + clean_shape: Whether we want clean shapes after the transform. + is_clean: The current cleanup status. + has_valid_shapes: The current shape propagation status. Returns: - A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the - pre-cleanup. - - This is used to ensure the transform is applied to a clean graph as needed by the transform. + An info object indicating the cleanup status after this function is called. 
""" - if not self.config.requires_clean_graph: - return info.is_clean, info.has_valid_shapes - - is_clean = info.is_clean - has_valid_shapes = is_clean and info.has_valid_shapes - - use_meta = isinstance(gm, GraphModule) and placeholders_on_meta(gm) - # check if run cleanup depending on the config and info - if self.config.requires_shape_prop and not has_valid_shapes: - self._log_info("running pre-cleanup with shape_prop") - canonicalize_graph(gm) - with lift_to_meta(gm) if use_meta else nullcontext(): - run_shape_prop(gm) + if clean_shape and not (is_clean and has_valid_shapes): + self._log_info("running graph cleanup (with shape_prop)") + canonicalize_graph(mod) + with lift_to_meta(mod) if placeholders_on_meta(mod) else nullcontext(): + run_shape_prop(mod) is_clean = True has_valid_shapes = True - elif self.config.requires_clean_graph and not is_clean: - self._log_info("running pre-cleanup (no shape_prop)") - canonicalize_graph(gm) + elif clean_graph and not is_clean: + self._log_info("running graph cleanup (no shape_prop)") + canonicalize_graph(mod) is_clean = True + has_valid_shapes = False - return is_clean, has_valid_shapes - - @final - def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo: - """Run graph cleanup after the transform. + return TransformInfo(is_clean=is_clean, has_valid_shapes=has_valid_shapes) - Cleanup is done as requested in the config and we will update the graph module and info - accordingly. - - Returns: - Updated TransformInfo with cleanup status. - """ - if not self.config.run_graph_cleanup: - return info - - use_meta = isinstance(gm, GraphModule) and placeholders_on_meta(gm) - - # check if run cleanup depending on the config and info - if self.config.run_shape_prop and not (info.is_clean and info.has_valid_shapes): - self._log_info("running post-cleanup with shape_prop") - canonicalize_graph(gm) - with lift_to_meta(gm) if use_meta else nullcontext(): - run_shape_prop(gm) - elif self.config.run_graph_cleanup and not info.is_clean: - self._log_info("running post-cleanup (no shape_prop)") - canonicalize_graph(gm) - - # create new info object with updated cleanup status - info_dict = info.model_dump() - info_dict["is_clean"] |= self.config.run_graph_cleanup - info_dict["has_valid_shapes"] |= self.config.run_shape_prop - return TransformInfo(**info_dict) - - @abstractmethod def _apply( self, gm: GraphModule, @@ -478,6 +483,21 @@ def _apply( This is the core method that should be implemented by subclasses. """ + raise NotImplementedError( + f"Transform {self.get_transform_key()} only supports `run_per_gm=False`." + ) + + def _apply_to_full_model( + self, + model: nn.Module, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[nn.Module, TransformInfo]: + """Apply the transform to the full model.""" + raise NotImplementedError( + f"Transform {self.get_transform_key()} only supports `run_per_gm=True`." 
+ ) class TransformRegistry: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py index 96a81dbfec7..b166c2acd6e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py @@ -2,8 +2,8 @@ from typing import Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...models import ModelFactory, hf from ...shim.interface import CachedSequenceInterface @@ -36,23 +36,20 @@ class BuildModel(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return BuildModelConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: # build the model model = factory.build_model(self.config.device) - # as wrapper to satisfy the interface we will register the model as a submodule - gm.add_module("factory_model", model) - - # by convention, we say this fake graph module is always clean + # by convention, we say the model is always clean info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return gm, info + return model, info @TransformRegistry.register("build_and_load_factory_model") @@ -68,22 +65,19 @@ class BuildAndLoadFactoryModel(BuildModel): config: BuildModelConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: # load model with auto sharding assert isinstance(factory, hf.AutoModelFactory), "Only HF models are supported." 
# build and load the model model = factory.build_and_load_model(self.config.device) - # as wrapper to satisfy the interface we will register the model as a submodule - gm.add_module("factory_model", model) - # this ensures that extra_args are passed in as they are received instead of enforcing the # registered extra_args cm.info.use_strict_args = False @@ -95,4 +89,4 @@ def _apply( # by convention, we say this fake graph module is always clean info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return gm, info + return model, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py index a77fbd3ac85..d9dff807e18 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py @@ -1,7 +1,7 @@ from typing import List, Literal, Optional, Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...compile import CompileBackendRegistry from ...models.factory import ModelFactory @@ -39,18 +39,18 @@ class CompileModel(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return CompileModelConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: cm.info.set_generate_only_batch() compiler_cls = CompileBackendRegistry.get(self.config.compile_backend) - egm_compiled = compiler_cls( - gm, + mod_compiled = compiler_cls( + mod, args=(), kwargs=cm.named_args, max_batch_size=cm.info.max_batch_size, @@ -62,4 +62,4 @@ def _apply( # store info object about the transform info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return egm_compiled, info + return mod_compiled, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index e92d842958f..fffb8a25a0e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -2,8 +2,8 @@ from typing import List, Optional, Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...export import torch_export_to_gm from ...models.factory import ModelFactory @@ -49,25 +49,19 @@ class ExportToGM(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return ExportToGMConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: - # at this point we assume the gm is just a dummy graph module - assert len(gm.graph.nodes) == 0, "Expected empty graph module." 
- - # retrieve the actual model from the dummy graph module - model = gm.get_submodule("factory_model") - + ) -> Tuple[nn.Module, TransformInfo]: # set the example sequence cm.info.set_example_sequence(**factory.get_example_inputs()) # export the model to a graph module gm = torch_export_to_gm( - model, + mod, args=(), kwargs=cm.named_args, dynamic_shapes=cm.named_dynamic_shapes, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index be8bf4c7b3e..945b375139d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple, Type import torch +import torch.nn as nn from pydantic import Field from torch.fx import GraphModule, Node @@ -238,13 +239,13 @@ class ResizeKVCache(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return ResizeKVCacheConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: free_mem_ratio = self.config.free_mem_ratio def _get_mem_info_in_mb(): @@ -262,7 +263,7 @@ def _get_mem_info_in_mb(): if free_mem_ratio == 0.0: self._log_info(f"Skipping cache resize for {free_mem_ratio=}") - return gm, TransformInfo( + return mod, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) @@ -274,7 +275,7 @@ def _get_mem_info_in_mb(): free_mem_pre, _ = _get_mem_info_in_mb() self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}") - gm(**cm.named_args) + mod(**cm.named_args) free_mem_post, _ = _get_mem_info_in_mb() self._log_info(f"Free memory after forward pass (MB): {free_mem_post}") @@ -310,18 +311,18 @@ def _get_mem_info_in_mb(): has_valid_shapes=True, ) - return gm, info + return mod, info @TransformRegistry.register("initialize_cache") class InitializeCache(BaseTransform): - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: num_caches = cm.initialize_caches() self._log_info(f"Initialized {num_caches} caches for cached attention") @@ -329,4 +330,4 @@ def _apply( skipped=False, num_matches=num_caches, is_clean=True, has_valid_shapes=True ) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index ed218061ef4..bcc3db30256 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -6,7 +6,8 @@ import torch import torch.fx as fx -from torch.fx import GraphModule, Node +import torch.nn as nn +from torch.fx import Graph, GraphModule, Node from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -106,23 +107,24 @@ class DetectHFAttnLayers(BaseTransform): This is achieved by running a single forward pass to profile the model. 
""" - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - model = gm.factory_model - # Register profiler attn operator ALL_ATTENTION_FUNCTIONS.register("ad_profile_mha", fake_profiler_mha) + # let's start a fake graph module for making tracing/profiling easier + mod._gm = GraphModule(nn.Module(), Graph()) + # run the forward pass with the profiling function - with switch_attn_implementation(model.config, "ad_profile_mha"): + with switch_attn_implementation(mod.config, "ad_profile_mha"): # update the graph module with the fake attn nodes during the profiling run - profiling_metadata = {"gm": gm, "num_matches": 0} - model.forward(**cm.named_args, profiling_metadata=profiling_metadata) + profiling_metadata = {"gm": mod._gm, "num_matches": 0} + mod.forward(**cm.named_args, profiling_metadata=profiling_metadata) info = TransformInfo( skipped=False, @@ -131,7 +133,7 @@ def _apply( has_valid_shapes=True, ) - return gm, info + return mod, info def get_cached_attn( @@ -188,9 +190,9 @@ def cached_attn( return cached_attn -def forward_with_prepare_metadata(gm: GraphModule, **cm_kwargs): +def forward_with_prepare_metadata(mod: nn.Module, **cm_kwargs): """Run prepare_metadata as pre-processing step, add to kwargs, and then run regular forward.""" - + gm = mod._gm if hasattr(gm, "_prepare_metadata_info"): # collect args+constant args args = [cm_kwargs[k] for k in gm._prepare_metadata_info["arg_names"]] @@ -201,7 +203,7 @@ def forward_with_prepare_metadata(gm: GraphModule, **cm_kwargs): return_names = gm._prepare_metadata_info["return_names"] cm_kwargs.update({k: v for k, v in zip(return_names, metadata)}) - return gm.factory_model.forward(**cm_kwargs) + return mod._original_forward(**cm_kwargs) # TODO: how running different kv cache transforms look like? 
This one below wouldn't work if we @@ -242,28 +244,29 @@ def _insert_cached_attn_node( attn_node.meta["metadata_cache_buffer_keys"] = (*meta_nodes, *cache_nodes, *buffer_nodes) attn_node.meta["constants"] = constants - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: # switch to cached attn inputs from now cm.info.switch_to_cached_attn_inputs() - # run actual insert cached attn transform - gm, info = super()._apply(gm, cm, factory, shared_config) + # run actual insert cached attn transform with fake graph module + mod._gm, info = super()._apply(mod._gm, cm, factory, shared_config) # register cached attn operator and switch to cached forward function ALL_ATTENTION_FUNCTIONS.register("ad_cached_mha", get_cached_attn(self.attn_descriptor)) - gm.forward = types.MethodType(forward_with_prepare_metadata, gm) + mod._original_forward = mod.forward + mod.forward = types.MethodType(forward_with_prepare_metadata, mod) - # switch to cached attn implementation but _only_ for modules/configs that have a cached + # switch to cached attn implementation but _only_ for submodules/configs that have a cached # attn node (we don't want to switch to cached attn implementation for all modules) - for mod in gm.factory_model.modules(): - if hasattr(mod, "_node_ref"): - mod.config._attn_implementation = "ad_cached_mha" + for submod in mod.modules(): + if hasattr(submod, "_node_ref"): + submod.config._attn_implementation = "ad_cached_mha" # we assume graph is clean again by definition info_dict = info.model_dump() @@ -271,4 +274,4 @@ def _apply( info_dict["has_valid_shapes"] = True info = TransformInfo(**info_dict) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py index 4967e6638e9..fb229bdd56c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py @@ -2,8 +2,8 @@ from typing import Optional, Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface @@ -36,22 +36,22 @@ class LoadWeightsToDevice(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return MoveDeviceConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: factory.load_or_random_init( - gm, + mod, device=self.config.adconfig_checkpoint_device or self.config.device, ) - move_to_device(gm, self.config.device) + move_to_device(mod, self.config.device) info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) - return gm, info + return mod, info @TransformRegistry.register("move_inputs_to_device") @@ -64,15 +64,15 @@ class LoadFactoryModelWeights(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return MoveDeviceConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, 
TransformInfo]: cm.to(self.config.device) info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py index 53659bf8140..31087dbd436 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py @@ -4,7 +4,6 @@ import torch.distributed as dist import torch.nn as nn -from torch.fx import Graph, GraphModule from ..distributed import common as dist_ad from ..models.factory import ModelFactory @@ -44,41 +43,32 @@ def _clean_config(self, config: InferenceOptimizerConfig) -> StrictInferenceOpti # return strict config return strict_config - @staticmethod - def _init_gm() -> GraphModule: - """Initialize a fake graph module. - - This is a dummy graph module that will be used to kick off the transforms. - """ - return GraphModule(nn.Module(), Graph()) - - def __call__( - self, cm: CachedSequenceInterface, gm: Optional[GraphModule] = None - ) -> GraphModule: + def __call__(self, cm: CachedSequenceInterface, mod: Optional[nn.Module] = None) -> nn.Module: """Transform a model into an optimized inference model. Args: cm: The cached sequence interface defining the sequence interface. + mod: The model to transform. Returns: - A GraphModule representing the optimized inference model. + A nn.Module representing the optimized inference model. """ ############################################################################################ # RUN THROUGH CONFIGURED TRANSFORMATIONS ############################################################################################ - # start with an empty fake graph module if not provided - if gm is None: - gm = self._init_gm() + # start with an empty model if not provided + if mod is None: + mod = nn.Module() # iterate over all transforms sorted by stage in the config for t_name, t_config in self.config.items(): # instantiate transform transform = TransformRegistry.get(t_name)(t_config) # run transform - gm = transform(gm, cm, self.factory, self.shared_config) + mod = transform(mod, cm, self.factory, self.shared_config) ############################################################################################ - # RETURN OPTIMIZED GRAPH + # RETURN OPTIMIZED MODEL ############################################################################################ - return gm + return mod diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py index 273e509714c..cea1d80219a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py @@ -132,12 +132,12 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None: gm.recompile() -def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> None: +def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None: """Move the entire graph module and all sub-GraphModules to the specified device.""" # get device device = torch.device(device) - for _, subgm in reversed(list(named_graphmodules(gm))): + for _, subgm in reversed(list(named_graphmodules(mod))): # recompile graph to update self generated codes in subgraph _move_single_gm_to_device(subgm, device) @@ -171,20 +171,20 @@ def _canonicalize_single_gm(gm: GraphModule) -> None: gm.graph.lint() -def canonicalize_graph(gm: GraphModule) -> None: +def 
canonicalize_graph(mod: nn.Module) -> None: """Canonicalize the graph of the given GraphModule. Args: - gm: The GraphModule to canonicalize. + mod: The model containing GraphModules to canonicalize. Returns: - The canonicalized (cleaned-up) GraphModule. + The canonicalized (cleaned-up) model. """ - ad_logger.debug(f"Before canonicalizing: {gm}") + ad_logger.debug(f"Before canonicalizing: {mod}") - for _, subgm in reversed(list(named_graphmodules(gm))): + for _, subgm in reversed(list(named_graphmodules(mod))): _canonicalize_single_gm(subgm) - ad_logger.debug(f"After canonicalizing: {gm}") + ad_logger.debug(f"After canonicalizing: {mod}") def _run_shape_prop_single_gm( @@ -216,7 +216,7 @@ def _run_shape_prop_single_gm( def run_shape_prop( - gm: GraphModule, + mod: nn.Module, args_static: Optional[Tuple[Any, ...]] = None, ) -> None: """Run FakeTensor-based shape propagation on the given GraphModule and its submodules. @@ -228,19 +228,19 @@ def run_shape_prop( are synthesized from the static arguments. Args: - gm: The top-level GraphModule on which to run shape propagation. All nested - GraphModules are processed in reverse topological order. + mod: The top-level model containing GraphModules on which to run shape propagation. All + nested GraphModules are processed in reverse topological order. args_static: Optional tuple of concrete tensors used to create FakeTensors when placeholder metadata is missing. Only applied to the top-level GraphModule; submodules reuse their existing placeholder metadata. """ - ad_logger.debug(f"Before running shape propagation: {gm}") + ad_logger.debug(f"Before running shape propagation: {mod}") - for _, subgm in reversed(list(named_graphmodules(gm))): - _run_shape_prop_single_gm(subgm, args_static=args_static if subgm is gm else None) + for _, subgm in reversed(list(named_graphmodules(mod))): + _run_shape_prop_single_gm(subgm, args_static=args_static if subgm is mod else None) - ad_logger.debug(f"After running shape propagation: {gm}") + ad_logger.debug(f"After running shape propagation: {mod}") def add_graph_input( @@ -309,7 +309,7 @@ def call_post_init(spec): return in_node -def placeholders_on_meta(gm: GraphModule) -> bool: +def placeholders_on_meta(mod: nn.Module) -> bool: """ Return True if every placeholder node in the graph is on the meta device. 
""" @@ -324,17 +324,18 @@ def _is_meta_tensor(t) -> bool: # Fallback for objects with .is_meta attribute return bool(getattr(t, "is_meta", False)) - for n in gm.graph.nodes: - if n.op != "placeholder": - continue - val = n.meta.get("val", None) + for _, subgm in reversed(list(named_graphmodules(mod))): + for n in subgm.graph.nodes: + if n.op != "placeholder": + continue + val = n.meta.get("val", None) - # If placeholder packs multiple values, find the first tensor-like leaf - t = val - if isinstance(val, (list, tuple)): - t = next((x for x in val if hasattr(x, "device") or hasattr(x, "is_meta")), None) + # If placeholder packs multiple values, find the first tensor-like leaf + t = val + if isinstance(val, (list, tuple)): + t = next((x for x in val if hasattr(x, "device") or hasattr(x, "is_meta")), None) - if not _is_meta_tensor(t): - return False + if not _is_meta_tensor(t): + return False return True diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py index 05ac3c70d22..7ffb1709cb6 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py @@ -176,7 +176,7 @@ def test_prepare_metadata_cuda(conv_env): device = conv_env["device"] b, s = 4, 6 - input_ids = torch.randint(0, 1000, (b, s), device=device) + # input_ids = torch.randint(0, 1000, (b, s), device=device) position_ids = torch.arange(s, device=device).expand(b, -1) seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) @@ -186,7 +186,6 @@ def test_prepare_metadata_cuda(conv_env): page_size = 128 out = torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py index 58811d6cd70..e68dfa4f24b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py @@ -469,7 +469,7 @@ def test_metadata_preparation(self): batch_size, seq_len_val = 4, 8 device = self.device - input_ids = torch.randint(0, 1000, (batch_size, seq_len_val), device=device) + # input_ids = torch.randint(0, 1000, (batch_size, seq_len_val), device=device) position_ids = torch.arange(seq_len_val, device=device).expand(batch_size, -1) seq_len = torch.full((batch_size,), seq_len_val, device=device, dtype=torch.int32) input_pos = torch.zeros(batch_size, device=device, dtype=torch.int32) @@ -479,7 +479,7 @@ def test_metadata_preparation(self): # Test metadata preparation result = torch.ops.auto_deploy.torch_cached_attention_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128 + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128 ) # Verify result structure diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py index 502eb634dc3..4090821e252 100644 --- 
a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py @@ -168,7 +168,7 @@ def test_prepare_metadata(conv_env): device = conv_env["device"] b, s = 4, 6 - input_ids = torch.randint(0, 1000, (b, s), device=device) + # input_ids = torch.randint(0, 1000, (b, s), device=device) position_ids = torch.arange(s, device=device).expand(b, -1) seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) @@ -178,7 +178,6 @@ def test_prepare_metadata(conv_env): page_size = 128 out = torch.ops.auto_deploy.torch_causal_conv_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py index 601c61e5f45..3000880d435 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py @@ -183,7 +183,7 @@ def test_prepare_metadata(mamba_env): device = mamba_env["device"] b, s = 4, 6 - input_ids = torch.randint(0, 1000, (b, s), device=device) + # input_ids = torch.randint(0, 1000, (b, s), device=device) position_ids = torch.arange(s, device=device).expand(b, -1) seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) @@ -193,7 +193,6 @@ def test_prepare_metadata(mamba_env): page_size = 128 out = torch.ops.auto_deploy.torch_ssm_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 65faaa0263a..07cb79270c8 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -92,6 +92,9 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): ), get_small_model_config_pytest_param( "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + pytest_param_kwargs={ + "marks": pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5505835") + }, attn_backend="flashinfer", compile_backend="torch-simple", ), diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index fef458e1f7a..f154f00e408 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -163,6 +163,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): { "build_model": { "stage": "factory", + "run_per_gm": False, "device": "cuda", "run_graph_cleanup": False, "requires_clean_graph": False, @@ -170,6 +171,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): "export_to_gm": { "stage": "export", "strict": False, + "run_per_gm": False, "clone_state_dict": True, "run_graph_cleanup": False, "requires_clean_graph": False,
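The net effect of the custom-op changes above is that every `*_prepare_metadata` op drops `input_ids` and derives batch/sequence shape from `position_ids` instead. Below is a minimal usage sketch of the updated signature, modeled on the unit tests in this diff; the `cache_loc`/`pages_per_seq`/`slot_idx` values are illustrative assumptions (the tests build them in fixtures not shown here), and the op is assumed to be registered by importing its defining module.

```python
# Illustrative call of the updated prepare_metadata signature (input_ids removed,
# position_ids leads the argument list). Values mirror the unit tests in this diff;
# cache_loc, pages_per_seq and slot_idx layouts are assumptions, not taken from the tests.
import torch

# importing the defining module registers the custom op (assumption: no extra setup needed)
import tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv  # noqa: F401

device = "cuda"
b, s = 4, 6
position_ids = torch.arange(s, device=device).expand(b, -1)
seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32)
input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32)
cache_loc = torch.arange(b, device=device, dtype=torch.int32)    # assumed: one page per sequence
pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)  # assumed
slot_idx = torch.arange(b, device=device, dtype=torch.int32)     # assumed
page_size = 128

# Sanitization now keys off position_ids; per the docstring the op returns
# (seq_len_sanitized, seq_start, slot_idx_sanitized).
seq_len_sanitized, seq_start, slot_idx_sanitized = (
    torch.ops.auto_deploy.torch_causal_conv_prepare_metadata(
        position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
    )
)
```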
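The `__and__`/`__or__` operators added to `TransformInfo` replace the manual `model_dump()` merging that the old pre/post-cleanup code did. A self-contained sketch of the merge semantics, mirroring the fields and defaults from `interface.py` with a plain dataclass rather than the real pydantic model:

```python
# Self-contained mirror of the TransformInfo merge semantics added in interface.py,
# using a plain dataclass instead of the real pydantic model (defaults assumed to
# match the diff: skipped=True, num_matches=0).
from dataclasses import dataclass


@dataclass
class Info:
    skipped: bool = True
    num_matches: int = 0
    is_clean: bool = False
    has_valid_shapes: bool = False

    def __or__(self, other: "Info") -> "Info":
        # optimistic merge: a cleanup on either side marks the result clean
        return Info(
            skipped=self.skipped and other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean or other.is_clean,
            has_valid_shapes=self.has_valid_shapes or other.has_valid_shapes,
        )

    def __and__(self, other: "Info") -> "Info":
        # pessimistic merge: clean/shape flags survive only if both sides have them
        return Info(
            skipped=self.skipped and other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean and other.is_clean,
            has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
        )


info_apply = Info(skipped=False, num_matches=3)            # transform found 3 matches, graph dirty
info_cleanup = Info(is_clean=True, has_valid_shapes=True)  # result of a cleanup pass
print(info_apply & info_cleanup)  # is_clean=False: "&" keeps the pessimistic view after _apply
print(info_apply | info_cleanup)  # is_clean=True: "|" absorbs the cleanup status from _run_cleanup
```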
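With the new `run_per_gm` flag, a transform implements either `_apply` (run per `GraphModule`, the default) or `_apply_to_full_model` (run once on the whole `nn.Module`); the base class now raises `NotImplementedError` for whichever path is not overridden. A hypothetical whole-model transform following this pattern might look as below; the name `tag_modules` and its behavior are made up for illustration, and such a transform would set `run_per_gm: false` in its YAML entry, as this PR does for `build_model`, `load_weights`, and the cache transforms.

```python
# Hypothetical whole-model transform using the new interface split: it overrides
# _apply_to_full_model only, so it must run with run_per_gm: false. The registry key
# "tag_modules" and the tagging behavior are illustrative; import paths follow the
# files touched in this diff.
from typing import Tuple

import torch.nn as nn

from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import (
    BaseTransform,
    SharedConfig,
    TransformInfo,
    TransformRegistry,
)


@TransformRegistry.register("tag_modules")
class TagModules(BaseTransform):
    """Toy transform that operates on the whole model rather than per GraphModule."""

    def _apply_to_full_model(
        self,
        mod: nn.Module,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[nn.Module, TransformInfo]:
        num_tagged = 0
        for submod in mod.modules():
            submod._ad_tagged = True  # illustrative side effect only
            num_tagged += 1
        # no graph surgery happened, so we can report the graph as still clean
        info = TransformInfo(
            skipped=False, num_matches=num_tagged, is_clean=True, has_valid_shapes=True
        )
        return mod, info
```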