From 3433f1a521ad8868abfd609f4ac4a0149c5fd2ab Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:15:46 -0700 Subject: [PATCH 01/10] Patched incorrect starcoder tp config Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py b/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py new file mode 100644 index 00000000000..4d28bec3d17 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py @@ -0,0 +1,4 @@ +from transformers.models.starcoder2.configuration_starcoder2 import Starcoder2Config + +# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57 +Starcoder2Config.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise" From c10517289899ffb18d0e61cd3e6a1316f5361411 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Thu, 2 Oct 2025 09:04:05 -0700 Subject: [PATCH 02/10] Factory sharding supports quantized linear nodes Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/transform/library/sharding.py | 18 +++++++++++++++--- .../_torch/auto_deploy/utils/node_utils.py | 5 +++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 07f8df00e29..ca9d76600c6 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -292,7 +292,7 @@ def detect_sharding_from_factory_config( num_simple_shards = 0 num_row_col_shards = 0 - for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op): + for lin_node in filtered_nodes(gm.graph.nodes, [is_linear_op, is_fake_quantized_linear_op]): # use node's weight name to get the module name module_name = lin_node.args[1].target @@ -368,7 +368,7 @@ def detect_sharding_from_factory_config( ) num_row_col_shards += 1 else: - ad_logger.warning("Invalid sharding config. Skipping.") + ad_logger.warning(f"Unsupported sharding action {config}. Skipping.") else: # TODO: local refers to hybrid EP+TP parallelism. Not supported yet. ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.") @@ -387,7 +387,19 @@ def detect_sharding_from_factory_config( ) num_simple_shards += 1 else: - ad_logger.warning("Invalid sharding config. Skipping.") + ad_logger.warning( + f"Unsupported sharding action {config}. 
Fallback to simple shard" + ) + sharding_config.tp_transforms.append( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) + ) # after successful match, break the loop break diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index c0d76c0548b..83b52be6a90 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -239,6 +239,11 @@ def filtered_nodes( for node in nodes: if target(node): yield node + elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target): + for node in nodes: + for t in target: + if t(node): + yield node else: # Handle the case where target or ops contains operations operations = ops if ops is not None else target From 8ae88698807b63679a517fe4427f0c9d832e825b Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Thu, 2 Oct 2025 10:58:57 -0700 Subject: [PATCH 03/10] sharding from user-provided config Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 6 +- tensorrt_llm/_torch/auto_deploy/llm_args.py | 17 - .../_torch/auto_deploy/transform/interface.py | 10 + .../auto_deploy/transform/library/sharding.py | 312 +++++++++--------- .../auto_deploy/utils/sharding_utils.py | 57 +++- .../library/test_bmm_sharding.py | 4 +- .../library/test_ep_sharding.py | 4 +- .../library/test_tp_sharding.py | 6 +- 8 files changed, 229 insertions(+), 187 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 0a33d0d9037..442b82dcd81 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -72,8 +72,10 @@ transforms: detect_sharding: stage: sharding simple_shard_only: false - use_sharding_from_factory: false - support_partial_config: false + # sharding_source: ['factory', 'custom', 'heuristic'] + sharding_source: ['heuristic'] + support_partial_config: true + # custom_sharding_config: 'tp_sharding.yaml' sharding_dims: ['tp', 'ep', 'bmm'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index e4ac0db0752..c0ca5d896f2 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -145,23 +145,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): description="The fraction of available memory to allocate for cache.", ) - simple_shard_only: bool = Field( - default=False, - description="If True, force simple sharding (all_gather) in tensor parallelism. " - "If False, auto-detect and use column+row (all_reduce) sharding when possible.", - ) - - use_sharding_from_factory: bool = Field( - default=False, - description="If True, use sharding from the model factory. 
If False, use sharding from the " - "AutoDeployConfig.", - ) - - sharding_dims: List[str] = Field( - default=["tp", "ep", "dp"], - description="The sharding methods to apply by the heuristic sharding stage.", - ) - compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( Field( default="torch-compile", diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 15799e4a802..9fcd05a2212 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -132,6 +132,16 @@ class TransformInfo(BaseModel): "tensors in the graph and it preserves the has_valid_shapes flag of the last transform.", ) + # overload += operator to concatenate TransformInfo objects + def __iadd__(self, other: "TransformInfo") -> "TransformInfo": + # since TransformInfo is frozen, instead, we return a new TransformInfo object + return TransformInfo( + skipped=self.skipped & other.skipped, + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean & other.is_clean, + has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes, + ) + TransformHistory = Dict[str, TransformInfo] diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index ca9d76600c6..de0aaaff616 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -39,6 +39,8 @@ BMMShardingInfo, EPShardingInfo, ShardingConfig, + ShardingDim, + ShardingSource, ShardingTransformInfo, SplitDimension, TPShardingInfo, @@ -106,30 +108,36 @@ def _append_simple_shard( ) -> None: # for every linear node: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] + num_shards = 0 for node_group in nodes_linear.values(): for n in node_group: - tp_shards.append( - TPShardingInfo.from_node( - n, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + n, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) - sharding_config.tp_transforms.extend(tp_shards) + return num_shards class ShardingTransformConfig(TransformConfig): """Configuration for sharding transformations.""" simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = Field(default=False) - support_partial_config: bool = Field(default=False) - # Which sharding families to run: any subset of {"tp", "ep", "bmm"} - sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"]) + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) + custom_sharding_config: str = Field(default="") + # Which sharding dimensions to run: any subset of {"tp", "ep", "bmm"} + sharding_dims: List[ShardingDim] = Field( + default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] + ) @TransformRegistry.register("detect_sharding") @@ -166,11 +174,11 @@ def _apply( ) -> Tuple[GraphModule, TransformInfo]: local_rank, world_size = shared_config.local_rank, shared_config.world_size - if world_size < 2: - ad_logger.info("Skipping sharding for single device") - return gm, TransformInfo( - skipped=True, num_matches=0, is_clean=True, 
has_valid_shapes=True - ) + # if world_size < 2: + # ad_logger.info("Skipping sharding for single device") + # return gm, TransformInfo( + # skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + # ) assert isinstance(gm, GraphModule), "Expecting GraphModule" shared_config.sharding_config.rank = local_rank @@ -187,58 +195,55 @@ def _apply( ) shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only shared_config.sharding_config.support_partial_config = self.config.support_partial_config + shared_config.sharding_config.sharding_source = self.config.sharding_source + if isinstance(self.config.custom_sharding_config, str): + shared_config.sharding_config.read_custom_sharding_config( + self.config.custom_sharding_config + ) shared_config.sharding_config.sharding_dims = self.config.sharding_dims - shared_config.sharding_config.use_sharding_from_factory = ( - self.config.use_sharding_from_factory - ) - sharding_config = shared_config.sharding_config sharding_config.validate_config() - if ( - shared_config.sharding_config.use_sharding_from_factory - and len(shared_config.sharding_config.get_predefined_config()) > 0 - ): - ad_logger.info("Applying sharding from config") - factory_info = detect_sharding_from_factory_config(gm, sharding_config) - return gm, factory_info + info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) + for source in shared_config.sharding_config.sharding_source: + if source == ShardingSource.FACTORY: + if len(shared_config.sharding_config.get_predefined_config()) == 0: + ad_logger.warning( + "No factory config found. Skipping sharding from factory config" + ) + continue + ad_logger.info("Applying sharding from factory config") + info += detect_sharding_from_factory_config(gm, sharding_config) - ad_logger.info( - f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" - ) - # run TP sharding across ranks - if "tp" in shared_config.sharding_config.sharding_dims: - tp_info = detect_column_row_shard(gm, sharding_config) - else: - tp_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + elif source == ShardingSource.CUSTOM: + ad_logger.info("Applying sharding from custom config") + if shared_config.sharding_config.custom_sharding_config is None: + ad_logger.warning( + "No custom sharding config found. 
Skipping sharding from custom config" + ) + continue + sharding_config.predefined_config["tp_plan"] = ( + shared_config.sharding_config.custom_sharding_config + ) + info += detect_sharding_from_factory_config(gm, sharding_config) - # run EP sharding across ranks - if "ep" in shared_config.sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + elif source == ShardingSource.HEURISTIC: + ad_logger.info( + f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" + ) + # run TP sharding across ranks + if ShardingDim.TP in shared_config.sharding_config.sharding_dims: + info += detect_column_row_shard(gm, sharding_config) - # run BMM sharding across ranks - if "bmm" in shared_config.sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + # run EP sharding across ranks + if ShardingDim.EP in shared_config.sharding_config.sharding_dims: + info += detect_ep_shard(gm, sharding_config) + + # run BMM sharding across ranks + if ShardingDim.BMM in shared_config.sharding_config.sharding_dims: + info += detect_dp_bmm_shard(gm, sharding_config) - info = TransformInfo( - skipped=tp_info.skipped and ep_info.skipped and dp_bmm_info.skipped, - num_matches=tp_info.num_matches + ep_info.num_matches + dp_bmm_info.num_matches, - is_clean=tp_info.is_clean and ep_info.is_clean and dp_bmm_info.is_clean, - has_valid_shapes=tp_info.has_valid_shapes - and ep_info.has_valid_shapes - and dp_bmm_info.has_valid_shapes, - ) return gm, info @@ -288,7 +293,6 @@ def detect_sharding_from_factory_config( "o_proj", ] - num_shards = 0 num_simple_shards = 0 num_row_col_shards = 0 @@ -310,32 +314,34 @@ def detect_sharding_from_factory_config( pattern_string = pattern_string.replace("*", "@") pattern_regex = re.escape(pattern_string).replace("@", ".*") if re.match(pattern_regex, module_name): - num_shards += 1 # we have a match. Get the config for this layer config = tp_plan[key] if config == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op=None, - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + ) ) ) elif config == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.ROW, - rank=rank, - world_size=world_size, - dist_op="all_reduce", - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + min_local_shape=min_local_shape, + ) ) ) - num_row_col_shards += 1 elif "sequence" in config: # TODO: Sequence parallelism is not supported yet. ad_logger.warning("Sequence parallelism is not supported yet. 
Skipping.") @@ -345,28 +351,31 @@ def detect_sharding_from_factory_config( if "shared" in module_name: col_row_action = config.replace("local_", "") if col_row_action == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op=None, - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + ) ) ) elif col_row_action == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.ROW, - rank=rank, - world_size=world_size, - dist_op="all_reduce", - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + min_local_shape=min_local_shape, + ) ) ) - num_row_col_shards += 1 else: ad_logger.warning(f"Unsupported sharding action {config}. Skipping.") else: @@ -375,64 +384,44 @@ def detect_sharding_from_factory_config( elif "gather" in config: # Simple shard (row + all_gather) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_simple_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) - num_simple_shards += 1 else: ad_logger.warning( f"Unsupported sharding action {config}. Fallback to simple shard" ) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_simple_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) # after successful match, break the loop break ad_logger.info( - f"Applied {num_shards} TP shards (simple: {num_simple_shards}, " - f"row-col pattern: {num_row_col_shards})" + f"Applied {num_simple_shards + num_row_col_shards} TP shards (simple: {num_simple_shards}, " + f"row-col: {num_row_col_shards})" ) num_matches = len(sharding_config.tp_transforms) - if sharding_config.support_partial_config: - ad_logger.info( - f"Partial factory config applied only for TP. " - f"Applying heuristics for {sharding_config.sharding_dims}." 
- ) - - # run EP sharding across ranks - if "ep" in sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) - - # run BMM sharding across ranks - if "bmm" in sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) - num_matches += ep_info.num_matches + dp_bmm_info.num_matches - return TransformInfo( skipped=False, num_matches=num_matches, @@ -515,7 +504,6 @@ def detect_column_row_shard( # col_split (dim 1) 2nd group + all_reduce output of 2nd group # 3. Linear nodes that are not in two groups or we cannot account for all nodes: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) output - num_shards = 0 num_simple_shards = 0 num_row_col_shards = 0 for n_start, n_end in zip(boundary_nodes[:-1], boundary_nodes[1:]): @@ -542,19 +530,19 @@ def detect_column_row_shard( if len(nodes_linear) == 0: continue - num_shards += 1 - if sharding_config.simple_shard_only: ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # simple shard when we have != 2 groups of linear nodes if len(nodes_linear) != 2: ad_logger.debug(f"Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # let's look at the unnacounted nodes. They are okay as long as they fall before the @@ -584,8 +572,9 @@ def detect_column_row_shard( # check if any unaccounted nodes are left. If so, do a simply shard if unaccounted_nodes or attention_related_nodes: ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # If we can account for all sharded nodes, we can do a two-way shard @@ -597,8 +586,9 @@ def detect_column_row_shard( # Column-row shard boundary region detection is probably wrong - there should be # only one attention operation. Fall back to simple shard. ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # Extract head dimension. We cannot shard below the head_dim size. 
# Assume that head_dim is the last (innermost) dimension of the tensor @@ -611,18 +601,20 @@ def detect_column_row_shard( dist_op = "all_reduce" else: dist_op = None - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - n, - split_dim=i, - rank=rank, - world_size=world_size, - dist_op=dist_op, - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + n, + split_dim=i, + rank=rank, + world_size=world_size, + dist_op=dist_op, + min_local_shape=min_local_shape, + ) ) ) - num_row_col_shards += 1 + num_shards = num_simple_shards + num_row_col_shards ad_logger.info( f"Found {num_shards} TP shards (simple: {num_simple_shards}, row-col: {num_row_col_shards})" ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 2e4c8a9e61e..ce9b3ae7273 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -1,14 +1,17 @@ """Sharding config definitions for the inference optimizer.""" +import json import math import operator from abc import ABC, abstractmethod -from enum import IntEnum +from enum import Enum, IntEnum from functools import partial +from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Sequence import torch import torch.nn as nn +import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator from torch.fx import GraphModule, Node @@ -834,6 +837,22 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]: return EPShardingInfo +class ShardingSource(Enum): + """Enum for sharding source.""" + + HEURISTIC = "heuristic" + FACTORY = "factory" + CUSTOM = "custom" + + +class ShardingDim(Enum): + """Enum for sharding dimension.""" + + TP = "tp" + EP = "ep" + BMM = "bmm" + + class ShardingConfig(BaseModel): """Configuration for sharding the model.""" @@ -841,9 +860,12 @@ class ShardingConfig(BaseModel): rank: int = Field(default=0) world_size: int = Field(default=1) predefined_config: Optional[Dict[str, Any]] = None + custom_sharding_config: Optional[Dict[str, Any]] = None simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = False support_partial_config: bool = False + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) sharding_dims: List[str] = Field(default_factory=list) tp_transforms: List[TPShardingInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) @@ -859,6 +881,37 @@ def _validate_and_normalize(self): self.validate_config() return self + def read_custom_sharding_config(self, config_path: str) -> bool: + """Read the custom sharding config from the given path. + + Supports both JSON and YAML file formats. The format is auto-detected + based on the file extension (.json, .yaml, .yml). + """ + path = Path(config_path) + + if not path.exists(): + ad_logger.warning(f"Sharding config file not found: {config_path}") + return False + + with open(config_path, "r") as f: + if path.suffix.lower() in [".yaml", ".yml"]: + self.custom_sharding_config = yaml.safe_load(f) + elif path.suffix.lower() == ".json": + self.custom_sharding_config = json.load(f) + else: + ad_logger.warning(f"Unsupported sharding config file format: {path.suffix}") + return True + + def append_TP(self, tp_transform: TPShardingInfo) -> bool: + """Append a TP transform only if that node was + not sharded before. 
Do not overwrite existing transforms. + """ + for existing_transform in self.tp_transforms: + if existing_transform.target_node == tp_transform.target_node: + return False + self.tp_transforms.append(tp_transform) + return True + def validate_config(self) -> bool: if self.factory_source != ShardingConfigSource.HUGGINGFACE: ad_logger.warning( diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index add1b399229..efa1aeddcd4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -66,7 +66,7 @@ def _run_job( { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], "sharding_dims": ["bmm"], }, "sharding_transform_executor": { @@ -128,7 +128,7 @@ def _run_pattern_detection_job( { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 94e236cd4e4..0dde0608b0b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -50,7 +50,7 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], "sharding_dims": ["ep"], }, "sharding_transform_executor": { @@ -118,7 +118,7 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index c4554bf89b0..019a665278c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -193,12 +193,13 @@ def verify_local_weight_sizes(gm) -> bool: op_expected = getattr(torch.ops.auto_deploy, dist_op_expected) gm = torch_export_to_gm(model, args=(x,), clone=True) + sharding_source = ["factory"] if from_config else ["heuristic"] gm_transformed = InferenceOptimizer( None, { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": from_config, + "sharding_source": sharding_source, }, "sharding_transform_executor": { "stage": "sharding", @@ -338,13 +339,14 @@ def _run_pattern_detection_job( ) ) + sharding_source = ["factory"] if from_config else ["heuristic"] # get detected transformations optimizer = InferenceOptimizer( None, { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": from_config, + "sharding_source": sharding_source, }, }, ) From 1af20e023708c84942ef22d9f440478e709b26f3 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Thu, 2 Oct 2025 13:36:29 -0700 Subject: [PATCH 04/10] 
fixing sharding tests Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/utils/sharding_utils.py | 18 +++++++++++------- .../_utils_test/_graph_test_helpers.py | 4 +++- .../library/test_bmm_sharding.py | 2 ++ .../library/test_ep_sharding.py | 2 ++ .../library/test_tp_sharding.py | 7 ++++++- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index ce9b3ae7273..57e533a8f96 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -893,13 +893,17 @@ def read_custom_sharding_config(self, config_path: str) -> bool: ad_logger.warning(f"Sharding config file not found: {config_path}") return False - with open(config_path, "r") as f: - if path.suffix.lower() in [".yaml", ".yml"]: - self.custom_sharding_config = yaml.safe_load(f) - elif path.suffix.lower() == ".json": - self.custom_sharding_config = json.load(f) - else: - ad_logger.warning(f"Unsupported sharding config file format: {path.suffix}") + try: + with open(config_path, "r") as f: + if path.suffix.lower() in [".yaml", ".yml"]: + self.custom_sharding_config = yaml.safe_load(f) + elif path.suffix.lower() == ".json": + self.custom_sharding_config = json.load(f) + else: + ad_logger.warning(f"Unsupported sharding config file format: {path.suffix}") + except Exception as e: + ad_logger.warning(f"Failed to read sharding config file: {e}") + return False return True def append_TP(self, tp_transform: TPShardingInfo) -> bool: diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index b00ee2bb97a..62cbcaed782 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -267,4 +267,6 @@ def run_sharding_pattern_detection_test( print("detected_set", detected_set) print("expected_set", expected_set) - assert detected_set == expected_set, "Expected sharding pattern does not match detected pattern" + assert detected_set == expected_set, ( + f"Expected sharding pattern does not match detected pattern: {detected_set} != {expected_set}" + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index efa1aeddcd4..784675757be 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -68,6 +68,7 @@ def _run_job( "stage": "sharding", "sharding_source": ["heuristic"], "sharding_dims": ["bmm"], + "support_partial_config": False, }, "sharding_transform_executor": { "stage": "sharding", @@ -129,6 +130,7 @@ def _run_pattern_detection_job( "detect_sharding": { "stage": "sharding", "sharding_source": ["heuristic"], + "support_partial_config": False, }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 0dde0608b0b..c2cc7ed46bb 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ 
b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -52,6 +52,7 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: "stage": "sharding", "sharding_source": ["heuristic"], "sharding_dims": ["ep"], + "support_partial_config": False, }, "sharding_transform_executor": { "stage": "sharding", @@ -119,6 +120,7 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> "detect_sharding": { "stage": "sharding", "sharding_source": ["heuristic"], + "support_partial_config": False, }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 019a665278c..f75d9cd19b9 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -1,5 +1,6 @@ """Tests for basic graph sharding.""" +# add to the path directory 4 directories up from functools import partial from typing import Type @@ -200,6 +201,8 @@ def verify_local_weight_sizes(gm) -> bool: "detect_sharding": { "stage": "sharding", "sharding_source": sharding_source, + "support_partial_config": False, + "sharding_dims": ["tp"], }, "sharding_transform_executor": { "stage": "sharding", @@ -347,6 +350,8 @@ def _run_pattern_detection_job( "detect_sharding": { "stage": "sharding", "sharding_source": sharding_source, + "support_partial_config": False, + "sharding_dims": ["tp"], }, }, ) @@ -414,4 +419,4 @@ def test_sharding_pattern_detection( if __name__ == "__main__": - _run_pattern_detection_job(nn.Linear, False, 0, 8, False) + _run_pattern_detection_job(MLP, True, 0, 8, True) From 3d1eb435cd7c120f7c0c8b990a8c28ed9f75c1ab Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 5 Oct 2025 16:00:47 -0700 Subject: [PATCH 05/10] Fixed tp sharding test Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/transform/library/sharding.py | 24 +++++----- .../library/test_tp_sharding.py | 44 ++++++++++++++++--- 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index de0aaaff616..c6aaed35e4d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -174,11 +174,11 @@ def _apply( ) -> Tuple[GraphModule, TransformInfo]: local_rank, world_size = shared_config.local_rank, shared_config.world_size - # if world_size < 2: - # ad_logger.info("Skipping sharding for single device") - # return gm, TransformInfo( - # skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - # ) + if world_size < 2: + ad_logger.info("Skipping sharding for single device") + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) assert isinstance(gm, GraphModule), "Expecting GraphModule" shared_config.sharding_config.rank = local_rank @@ -218,30 +218,28 @@ def _apply( elif source == ShardingSource.CUSTOM: ad_logger.info("Applying sharding from custom config") - if shared_config.sharding_config.custom_sharding_config is None: + if sharding_config.custom_sharding_config is None: ad_logger.warning( "No custom sharding config 
found. Skipping sharding from custom config" ) continue - sharding_config.predefined_config["tp_plan"] = ( - shared_config.sharding_config.custom_sharding_config - ) + sharding_config.predefined_config = sharding_config.custom_sharding_config info += detect_sharding_from_factory_config(gm, sharding_config) elif source == ShardingSource.HEURISTIC: ad_logger.info( - f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" + f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}" ) # run TP sharding across ranks - if ShardingDim.TP in shared_config.sharding_config.sharding_dims: + if ShardingDim.TP in sharding_config.sharding_dims: info += detect_column_row_shard(gm, sharding_config) # run EP sharding across ranks - if ShardingDim.EP in shared_config.sharding_config.sharding_dims: + if ShardingDim.EP in sharding_config.sharding_dims: info += detect_ep_shard(gm, sharding_config) # run BMM sharding across ranks - if ShardingDim.BMM in shared_config.sharding_config.sharding_dims: + if ShardingDim.BMM in sharding_config.sharding_dims: info += detect_dp_bmm_shard(gm, sharding_config) return gm, info diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index f75d9cd19b9..5697acee625 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import yaml from _dist_test_utils import get_device_counts from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm from _model_test_utils import FakeFP8Linear @@ -194,13 +195,19 @@ def verify_local_weight_sizes(gm) -> bool: op_expected = getattr(torch.ops.auto_deploy, dist_op_expected) gm = torch_export_to_gm(model, args=(x,), clone=True) - sharding_source = ["factory"] if from_config else ["heuristic"] + sharding_source = ["custom"] if from_config else ["heuristic"] + + if sharding_source == ["custom"]: + # write predefined_config to tp_sharding.yaml file + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) gm_transformed = InferenceOptimizer( None, { "detect_sharding": { "stage": "sharding", "sharding_source": sharding_source, + "custom_sharding_config": "tp_sharding.yaml", "support_partial_config": False, "sharding_dims": ["tp"], }, @@ -342,7 +349,34 @@ def _run_pattern_detection_job( ) ) - sharding_source = ["factory"] if from_config else ["heuristic"] + sharding_source = ["custom"] if from_config else ["heuristic"] + + if sharding_source == ["custom"]: + # write predefined_config to tp_sharding.yaml file + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) + InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "sharding_source": sharding_source, + "custom_sharding_config": "tp_sharding.yaml", + "support_partial_config": False, + "sharding_dims": ["tp"], + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) + + sharding_source = ["custom"] if from_config else ["heuristic"] + + if sharding_source == ["custom"]: + # write predefined_config to tp_sharding.yaml file + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) # get 
detected transformations optimizer = InferenceOptimizer( None, @@ -350,6 +384,7 @@ def _run_pattern_detection_job( "detect_sharding": { "stage": "sharding", "sharding_source": sharding_source, + "custom_sharding_config": "tp_sharding.yaml", "support_partial_config": False, "sharding_dims": ["tp"], }, @@ -357,6 +392,7 @@ def _run_pattern_detection_job( ) optimizer.shared_config.local_rank = rank optimizer.shared_config.world_size = world_size + optimizer.shared_config.sharding_config.predefined_config = predefined_config _ = optimizer(None, gm) detected_transformations = optimizer.shared_config.sharding_config.tp_transforms @@ -416,7 +452,3 @@ def test_sharding_pattern_detection( No need to run distributed job, can be run on single process. """ _run_pattern_detection_job(model_cls, bias, 0, world_size, from_config) - - -if __name__ == "__main__": - _run_pattern_detection_job(MLP, True, 0, 8, True) From 0972ed6f6af4a4f6b54e580a05fab052138de3f7 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 8 Oct 2025 05:17:57 -0700 Subject: [PATCH 06/10] fixed llm args Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/llm_args.py | 6 ++++++ .../multigpu/transformations/library/test_tp_sharding.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index c0ca5d896f2..3dd7c145cb2 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -145,6 +145,12 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): description="The fraction of available memory to allocate for cache.", ) + simple_shard_only: bool = Field( + default=False, + description="If True, force simple sharding (all_gather) in tensor parallelism. 
" + "If False, auto-detect and use column+row (all_reduce) sharding when possible.", + ) + compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( Field( default="torch-compile", diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 5697acee625..75ab445464d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -408,10 +408,10 @@ def _run_pattern_detection_job( @pytest.mark.parametrize( "model_cls, dist_op_expected", ( - (MLP, "torch_dist_all_reduce"), - (FP8MLP, "torch_dist_all_reduce"), + # (MLP, "torch_dist_all_reduce"), + # (FP8MLP, "torch_dist_all_reduce"), (nn.Linear, "torch_dist_all_gather"), - (GQA_Block, "torch_dist_all_reduce"), + # (GQA_Block, "torch_dist_all_reduce"), ), ) def test_sharding( From 03f4190a3f91feaffa1c7c092cc931ba5a9c946c Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 8 Oct 2025 05:28:02 -0700 Subject: [PATCH 07/10] fixed llm args Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../multigpu/transformations/library/test_tp_sharding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 75ab445464d..5697acee625 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -408,10 +408,10 @@ def _run_pattern_detection_job( @pytest.mark.parametrize( "model_cls, dist_op_expected", ( - # (MLP, "torch_dist_all_reduce"), - # (FP8MLP, "torch_dist_all_reduce"), + (MLP, "torch_dist_all_reduce"), + (FP8MLP, "torch_dist_all_reduce"), (nn.Linear, "torch_dist_all_gather"), - # (GQA_Block, "torch_dist_all_reduce"), + (GQA_Block, "torch_dist_all_reduce"), ), ) def test_sharding( From 0cf4b2ae4884a9f6397989914dfc39635bfccbbf Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 12 Oct 2025 09:23:04 -0700 Subject: [PATCH 08/10] fixed filtered_nodes Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 83b52be6a90..321f326616b 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -244,6 +244,7 @@ def filtered_nodes( for t in target: if t(node): yield node + break else: # Handle the case where target or ops contains operations operations = ops if ops is not None else target From c46d7421dd39deef7343c96b1204e0ea94c7be62 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Mon, 13 Oct 2025 05:52:13 -0700 Subject: [PATCH 09/10] Fixed test_tp_sharding Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- 
.../library/test_tp_sharding.py | 38 +++++-------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 5697acee625..a71b9ffae13 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -1,6 +1,7 @@ """Tests for basic graph sharding.""" # add to the path directory 4 directories up +import os from functools import partial from typing import Type @@ -198,9 +199,10 @@ def verify_local_weight_sizes(gm) -> bool: sharding_source = ["custom"] if from_config else ["heuristic"] if sharding_source == ["custom"]: - # write predefined_config to tp_sharding.yaml file - with open("tp_sharding.yaml", "w") as f: - yaml.dump(predefined_config, f, sort_keys=False) + # If the file does not exist, write predefined_config to tp_sharding.yaml file + if not os.path.exists("tp_sharding.yaml"): + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) gm_transformed = InferenceOptimizer( None, { @@ -352,31 +354,11 @@ def _run_pattern_detection_job( sharding_source = ["custom"] if from_config else ["heuristic"] if sharding_source == ["custom"]: - # write predefined_config to tp_sharding.yaml file - with open("tp_sharding.yaml", "w") as f: - yaml.dump(predefined_config, f, sort_keys=False) - InferenceOptimizer( - None, - { - "detect_sharding": { - "stage": "sharding", - "sharding_source": sharding_source, - "custom_sharding_config": "tp_sharding.yaml", - "support_partial_config": False, - "sharding_dims": ["tp"], - }, - "sharding_transform_executor": { - "stage": "sharding", - }, - }, - )(None, gm) - - sharding_source = ["custom"] if from_config else ["heuristic"] + # If the file does not exist, write predefined_config to tp_sharding.yaml file + if not os.path.exists("tp_sharding.yaml"): + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) - if sharding_source == ["custom"]: - # write predefined_config to tp_sharding.yaml file - with open("tp_sharding.yaml", "w") as f: - yaml.dump(predefined_config, f, sort_keys=False) # get detected transformations optimizer = InferenceOptimizer( None, @@ -396,8 +378,6 @@ def _run_pattern_detection_job( _ = optimizer(None, gm) detected_transformations = optimizer.shared_config.sharding_config.tp_transforms - print(f"detected_transformations: {detected_transformations}") - print(f"expected_transformations: {expected_transformations}") # Run pattern detection test run_sharding_pattern_detection_test(detected_transformations, expected_transformations) From 8765bbe882dba28bef3fd8c4ed5bccb003ad879a Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 15 Oct 2025 04:40:35 -0700 Subject: [PATCH 10/10] merged TransformInfo changes Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/transform/interface.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 9fcd05a2212..12958c12b1a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py 
@@ -132,6 +132,14 @@ class TransformInfo(BaseModel): "tensors in the graph and it preserves the has_valid_shapes flag of the last transform.", ) + @classmethod + def from_last_info(cls, info: "TransformInfo") -> "TransformInfo": + """Create a new TransformInfo from the last transform info.""" + return cls( + is_clean=info.is_clean, + has_valid_shapes=info.has_valid_shapes, + ) + # overload += operator to concatenate TransformInfo objects def __iadd__(self, other: "TransformInfo") -> "TransformInfo": # since TransformInfo is frozen, instead, we return a new TransformInfo object @@ -142,6 +150,24 @@ def __iadd__(self, other: "TransformInfo") -> "TransformInfo": has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes, ) + def __or__(self, other: "TransformInfo") -> "TransformInfo": + """Merge two TransformInfo objects.""" + return TransformInfo( + skipped=self.skipped and other.skipped, # we only count skipped if both were skipped + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean or other.is_clean, + has_valid_shapes=self.has_valid_shapes or other.has_valid_shapes, + ) + + def __and__(self, other: "TransformInfo") -> "TransformInfo": + """Merge two TransformInfo objects.""" + return TransformInfo( + skipped=self.skipped and other.skipped, # we only count skipped if both were skipped + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean and other.is_clean, + has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes, + ) + TransformHistory = Dict[str, TransformInfo]
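A minimal usage sketch of the "custom" sharding source added by this series, modeled on the updated default.yaml and test_tp_sharding.py. The tp_plan layout and the action names below are assumptions inferred from detect_sharding_from_factory_config (which matches module-name patterns against actions such as "colwise", "rowwise", and "gather"); they are not a documented schema, and the snippet only prepares the config file and the transform options.

# Sketch only: the schema consumed by detect_sharding_from_factory_config is
# assumed here to be an HF-style "tp_plan" mapping of module-name patterns to
# sharding actions. Pattern names are illustrative, not taken from the patches.
import yaml

custom_plan = {
    "tp_plan": {
        "model.layers.*.self_attn.q_proj": "colwise",  # column split, no collective
        "model.layers.*.self_attn.k_proj": "colwise",
        "model.layers.*.self_attn.v_proj": "colwise",
        "model.layers.*.self_attn.o_proj": "rowwise",  # row split + all_reduce
        "model.layers.*.mlp.c_proj": "rowwise",
    }
}

# Write the file that `custom_sharding_config` points at; read_custom_sharding_config
# detects the format from the extension and accepts .yaml/.yml and .json.
with open("tp_sharding.yaml", "w") as f:
    yaml.dump(custom_plan, f, sort_keys=False)

# Transform options mirroring the updated default.yaml and the unit tests; these
# would be passed to InferenceOptimizer the same way the tests do.
transform_options = {
    "detect_sharding": {
        "stage": "sharding",
        "sharding_source": ["custom"],  # any subset of: "factory", "custom", "heuristic"
        "custom_sharding_config": "tp_sharding.yaml",
        "support_partial_config": False,
        "sharding_dims": ["tp"],
    },
    "sharding_transform_executor": {"stage": "sharding"},
}

Actions that fall outside the handled cases degrade to a simple shard (column split plus all_gather), as in the warning path above, and ShardingConfig.append_TP skips any node that an earlier sharding source has already covered.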