diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 13f1cf0703f..2c2e6dc761d 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -75,9 +75,9 @@ transforms:
   detect_sharding:
     stage: sharding
     simple_shard_only: false
-    use_sharding_from_factory: false
-    support_partial_config: false
-    sharding_dims: ['tp', 'ep', 'bmm']
+    sharding_source: ['heuristic']
+    support_partial_config: true
+    sharding_dims: ['ssm', 'tp', 'ep', 'bmm']
     requires_shape_prop: true
     # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
index 396711bd80c..4ed4d2b036e 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
@@ -9,6 +9,7 @@
 from einops import rearrange
 from transformers import AutoModelForCausalLM
 
+from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
 from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward
 
 
@@ -79,7 +80,7 @@ def _nemotron_h_block_forward(
     elif self.block_type == "attention":
         hidden_states = self.mixer(hidden_states, cache_position=cache_position)
         hidden_states = hidden_states[0]
-    elif self.block_type == "mlp":
+    elif self.block_type in ["mlp", "moe"]:
         hidden_states = self.mixer(hidden_states)
     else:
         raise ValueError(f"Invalid block_type: {self.block_type}")
@@ -112,6 +113,19 @@ def get_model_from_config_patched(config, **kwargs):
     return model
 
 
+def _set_sharding_config_patched(self, *args, **kwargs):
+    self._sharding_config["head_dim"] = 128
+    self._sharding_config["tp_plan"] = {
+        "in_proj": "mamba",
+        "out_proj": "rowwise",
+        # "up_proj": "colwise",
+        # "down_proj": "rowwise",
+        # "*": "gather",
+    }
+
+
+AutoModelForCausalLMFactory._set_sharding_config = _set_sharding_config_patched
+
 # TODO: figure out how this can be incorporated into the export patch system
 AutoModelForCausalLM.from_config = get_model_from_config_patched
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py
index 0bd28a1d78d..81140fccb1b 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py
@@ -173,6 +173,15 @@ def __and__(self, other: "TransformInfo") -> "TransformInfo":
             has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
         )
 
+    # Implement the + operator so results from successive detection passes can be accumulated.
+    def __add__(self, other: "TransformInfo") -> "TransformInfo":
+        return TransformInfo(
+            skipped=self.skipped and other.skipped,
+            num_matches=self.num_matches + other.num_matches,
+            is_clean=self.is_clean and other.is_clean,
+            has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
+        )
+
 
 TransformHistory = Dict[str, TransformInfo]
 
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
index 669357b1399..6980b5dadc3 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
@@ -294,7 +294,7 @@ def _find_final_hidden_state_node(
     if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
         return None
     index_node = mul_node.args[1]
-
index_add_node = bfs( + index_add_node, _ = bfs( index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary ) if not index_add_node: @@ -360,7 +360,7 @@ def target(n: torch.fx.Node) -> bool: return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0 try: - node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) + node_to_remove, _ = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) graph.erase_node(node_to_remove) return True except RuntimeError: @@ -430,7 +430,7 @@ def _apply( common_ancessor2 = _find_lowest_common_ancessor(arg2_list) if not common_ancessor2: continue - selected_experts = bfs( + selected_experts, _ = bfs( common_ancessor2, lambda node: is_op(node, torch.ops.aten.one_hot), attr_next="all_input_nodes", diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index 477cde8e02d..e04f3212722 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -13,7 +13,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger -from ...utils.node_utils import extract_param_names_from_lin_node, is_linear_op, is_op +from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node y2 = y[:, out1:out1+out2] """ # some info we need - keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes] + keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes] params_unfused = [gm.get_parameter(k) for k in keys_unfused] sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" @@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple def _insert_fused_quant_gemm( self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node] ): - keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes] + keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes] params_unfused = [gm.get_parameter(k) for k in keys_unfused] sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 94137e9a0b1..9f53a9bd637 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -14,7 +14,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import ( - extract_param_names_from_lin_node, + extract_param_names_from_node, get_quantization_params_from_linear_node, is_bmm_op, is_linear_op, @@ -136,7 +136,7 @@ def _insert_quantized_linear( The state_dict is also updated to contain the sharded weights. 
""" - param_name, _ = extract_param_names_from_lin_node(node) + param_name, _ = extract_param_names_from_node(node) original_weight = gm.get_parameter(param_name) new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False) modname, _, attrname = param_name.rpartition(".") diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index d6de54b22d2..25b2330c38a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -29,19 +29,26 @@ from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger from ...utils.node_utils import ( + bfs, filtered_nodes, identify_regions_between_residuals, is_fake_quantized_linear_op, is_linear_op, is_op, + subgraph, ) from ...utils.sharding_utils import ( BMMShardingInfo, EPShardingInfo, + LayerType, + ParameterUpdateInfo, ShardingConfig, + ShardingDim, + ShardingSource, ShardingTransformInfo, SplitDimension, - TPShardingInfo, + WeightShardingInfo, + get_all_weights_in_subgraph, ) from ..interface import ( BaseTransform, @@ -82,7 +89,7 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: return transform.check_and_apply(gm, node_dict[transform.target_node]) num_matches = 0 - for tp_transform in shared_config.sharding_config.tp_transforms: + for tp_transform in shared_config.sharding_config.weight_sharding_transforms: if check_and_apply(tp_transform): num_matches += 1 for bmm_transform in shared_config.sharding_config.bmm_transforms: @@ -92,44 +99,56 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: if check_and_apply(ep_transform): num_matches += 1 + # post-sharding cleanup transformations + for update_transform in shared_config.sharding_config.parameter_update_transforms: + if not check_and_apply(update_transform): + ad_logger.warning(f"Invalid parameter update transformation {update_transform}.") + info = TransformInfo( skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False ) + # exit() return gm, info -def _append_simple_shard( +def _process_simple_shard( nodes_linear: Dict[Node, List[Node]], rank: int, world_size: int, sharding_config: ShardingConfig, -) -> None: +) -> int: # for every linear node: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] + num_simple_shards = 0 for node_group in nodes_linear.values(): for n in node_group: - tp_shards.append( - TPShardingInfo.from_node( - n, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_simple_shards += int( + sharding_config.add( + WeightShardingInfo.from_node( + n, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) - sharding_config.tp_transforms.extend(tp_shards) + return num_simple_shards class ShardingTransformConfig(TransformConfig): """Configuration for sharding transformations.""" simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = Field(default=False) + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) support_partial_config: bool = Field(default=False) - # Which sharding families to run: any subset of {"tp", "ep", "bmm"} - sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"]) + # Which sharding dimensions to run: any 
subset of {"tp", "ep", "bmm"} + sharding_dims: List[ShardingDim] = Field( + default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] + ) @TransformRegistry.register("detect_sharding") @@ -165,6 +184,7 @@ def _apply( shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: local_rank, world_size = shared_config.local_rank, shared_config.world_size + # world_size = 2 if world_size < 2: ad_logger.info("Skipping sharding for single device") @@ -173,73 +193,310 @@ def _apply( ) assert isinstance(gm, GraphModule), "Expecting GraphModule" - shared_config.sharding_config.rank = local_rank - shared_config.sharding_config.world_size = world_size - shared_config.sharding_config.predefined_config = ( - factory.get_sharding_config() if factory else {} - ) - shared_config.sharding_config.factory_source = ( - shared_config.sharding_config.predefined_config.get( - "source", ShardingConfigSource.UNKNOWN - ) + sharding_config = shared_config.sharding_config + sharding_config.rank = local_rank + sharding_config.world_size = world_size + sharding_config.predefined_config = factory.get_sharding_config() if factory else {} + sharding_config.factory_source = ( + sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN) if factory else ShardingConfigSource.UNKNOWN ) - shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only - shared_config.sharding_config.support_partial_config = self.config.support_partial_config - shared_config.sharding_config.sharding_dims = self.config.sharding_dims + sharding_config.simple_shard_only = self.config.simple_shard_only + sharding_config.support_partial_config = self.config.support_partial_config + sharding_config.sharding_dims = self.config.sharding_dims + sharding_config.sharding_source = self.config.sharding_source + + sharding_config.validate_config() + + info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) + for source in sharding_config.sharding_source: + if source == ShardingSource.FACTORY: + if len(sharding_config.get_predefined_config()) == 0: + ad_logger.warning( + "No factory config found. Skipping sharding from factory config" + ) + continue + ad_logger.info("Applying sharding from factory config") + info += detect_sharding_from_factory_config(gm, sharding_config) + + elif source == ShardingSource.HEURISTIC: + ad_logger.info( + f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}" + ) + if ShardingDim.SSM in sharding_config.sharding_dims: + info += detect_ssm_shard(gm, sharding_config) + + # run TP sharding across ranks + if ShardingDim.TP in sharding_config.sharding_dims: + info += detect_column_row_shard(gm, sharding_config) + + # run EP sharding across ranks + if ShardingDim.EP in sharding_config.sharding_dims: + info += detect_ep_shard(gm, sharding_config) - shared_config.sharding_config.use_sharding_from_factory = ( - self.config.use_sharding_from_factory + # run BMM sharding across ranks + if ShardingDim.BMM in sharding_config.sharding_dims: + info += detect_dp_bmm_shard(gm, sharding_config) + + return gm, info + + +def _process_ssm_sharding( + gm: GraphModule, + entry_node: Node, + sharding_config: ShardingConfig, + rank: int, + world_size: int, + min_local_shape: int = 1, +) -> int: + """ + Process the SSM sharding from the candidate nodes and update the view and split nodes accordingly. 
+ """ + # Find next linear node to define subgraph boundary + try: + out_proj_node, _ = bfs(entry_node, is_linear_op, include_root=False) + except RuntimeError: + ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") + return 0 + + # Get subgraph between entry_node and next linear node + subgraph_nodes = subgraph([entry_node], [out_proj_node]) + + ############################################################## + ########## infer split sizes for in_proj and conv1d ########## + ############################################################## + # in_proj and conv1d are fused, followed up by split nodes. Infer split sizes: + split_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) + ] + if len(split_nodes) != 2: + ad_logger.warning( + f"Subgraph does not contain exactly two split nodes. " + f"Skipping Mamba sharding. split_nodes={split_nodes}" ) + return 0 + split_sizes_0 = split_nodes[0].args[1] + split_sizes_1 = split_nodes[1].args[1] + if split_sizes_0[1] != sum(split_sizes_1): + ad_logger.warning( + f"Split nodes have different sizes. " + f"Skipping Mamba sharding. split_sizes_1={split_sizes_0}, split_sizes_2={split_sizes_1}" + ) + return 0 + fused_weight_dims = { + "in_proj": split_sizes_0[0:1] + split_sizes_1 + split_sizes_0[2:], + "conv1d": split_sizes_1, + } - sharding_config = shared_config.sharding_config - sharding_config.validate_config() + # # ############################################################## + # # ############## update split nodes ############################ + # # ############################################################## + split_args_0 = list(split_nodes[0].args) + split_args_0[1] = [s // world_size for s in split_args_0[1]] + split_args_1 = list(split_nodes[1].args) + split_args_1[1] = [s // world_size for s in split_args_1[1]] + sharding_config.add( + ParameterUpdateInfo( + rank=rank, + world_size=world_size, + target_node=split_nodes[0].name, + args=tuple(split_args_0), + ) + ) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, + world_size=world_size, + target_node=split_nodes[1].name, + args=tuple(split_args_1), + ) + ) - if ( - shared_config.sharding_config.use_sharding_from_factory - and len(shared_config.sharding_config.get_predefined_config()) > 0 - ): - ad_logger.info("Applying sharding from config") - factory_info = detect_sharding_from_factory_config(gm, sharding_config) - return gm, factory_info + # ############################################################## + # ############# update conv1d num output channels ############## + # ############################################################## + conv1d_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) + ] + assert len(conv1d_nodes) == 1, "Expecting exactly one conv1d node" + conv1d_node = conv1d_nodes[0] + # conv1d_node last argument is the number of output channels. 
+ # This one is also sharded, so we need to update this parameter + conv_args = list(conv1d_node.args) + conv_args[-1] = conv1d_node.args[-1] // world_size + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=conv1d_node.name, args=tuple(conv_args) + ) + ) - ad_logger.info( - f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" + # ############################################################## + # ####### shard the entry_node (the first linear layer) ######## + # ############################################################## + sharding_config.add( + WeightShardingInfo.from_node( + entry_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims["in_proj"], ) - # run TP sharding across ranks - if "tp" in shared_config.sharding_config.sharding_dims: - tp_info = detect_column_row_shard(gm, sharding_config) - else: - tp_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + ) - # run EP sharding across ranks - if "ep" in shared_config.sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + # ############################################################## + # ############## shard the remaining weights ################### + # ############################################################## + # # Get all weight nodes in the subgraph except for out_proj (it has to be row-sharded) + weight_nodes = [ + n + for n in get_all_weights_in_subgraph([entry_node], [out_proj_node]) + if "out_proj" not in str(n) + ] + for weight_node in weight_nodes: + weight_key = weight_node.target + # Get the weight parameter + try: + gm.get_parameter(weight_key) + except AttributeError: + ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") + continue + + # Get fused dims for this weight if specified + fused_dims = None + for k, v in fused_weight_dims.items(): + if k in weight_key: + fused_dims = v + break + + # Shard the weight tensor (also updates the parameter in the module) + sharding_config.add( + WeightShardingInfo.from_node( + list(weight_node.users)[0], + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + fused_weight_dims=fused_dims, ) + ) - # run BMM sharding across ranks - if "bmm" in shared_config.sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + # ############################################################## + # ############## update the view and reshape nodes ############# + # ############################################################## + nodes_to_validate = [ + n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) + ] + for view_node in nodes_to_validate: + if len(view_node.args) < 2: + continue + view_shape = list(view_node.args[1]) + if not isinstance(view_shape, list): + continue + if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: + args = list(view_node.args) + view_shape[2] = view_shape[2] // world_size + args[1] = tuple(view_shape) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=view_node.name, 
args=tuple(args) + ) ) + ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") + + ############################################################## + ############## shard the out_proj node ####################### + ############################################################## + sharding_config.add( + WeightShardingInfo.from_node( + out_proj_node, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + ) + ) + return 1 - info = TransformInfo( - skipped=tp_info.skipped and ep_info.skipped and dp_bmm_info.skipped, - num_matches=tp_info.num_matches + ep_info.num_matches + dp_bmm_info.num_matches, - is_clean=tp_info.is_clean and ep_info.is_clean and dp_bmm_info.is_clean, - has_valid_shapes=tp_info.has_valid_shapes - and ep_info.has_valid_shapes - and dp_bmm_info.has_valid_shapes, + +def _process_column_sharding( + gm: GraphModule, + linear_nodes: List[Node], + sharding_config: ShardingConfig, + rank: int, + world_size: int, + min_local_shape: int = 1, + fused_weight: bool = False, +) -> None: + """ + Parse the column sharding from the candidate nodes and update the view and split nodes accordingly. + """ + for linear_node in linear_nodes: + sharding_config.add( + WeightShardingInfo.from_node( + linear_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, # for column sharding, no dist op is performed + min_local_shape=min_local_shape, + ) ) - return gm, info + + # get the subgraph of this module. Subgraph boundary is the next linear node. + next_lin_node, depth = bfs(linear_nodes[0], is_linear_op, include_root=False) + subgraph_nodes = subgraph( + [linear_nodes], + [next_lin_node], + ) + + nodes_to_validate = [ + n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) + ] + for view_node in nodes_to_validate: + if len(view_node.args) < 2: + continue + view_shape = list(view_node.args[1]) + if not isinstance(view_shape, list): + continue + if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: + args = list(view_node.args) + view_shape[2] = view_shape[2] // world_size + args[1] = tuple(view_shape) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=view_node.name, args=tuple(args) + ) + ) + ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") + + # if fused_weight_dims is provided, we need to update all split sizes + if fused_weight: + assert len(linear_nodes) == 1, "Fused weight should be only one linear node" + node = linear_nodes[0] + assert world_size is not None, "World size is required to update the split node params" + assert len(node.users) == 1, "Fused linear node should have only one user: a split node" + user = list(node.users)[0] + if is_op(user, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]): + orig_sizes = user.args[1] + new_sizes = [orig_sizes[i] // world_size for i in range(len(orig_sizes))] + args = list(user.args) + args[1] = new_sizes + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=user.name, args=tuple(args) + ) + ) + ad_logger.debug( + f"\nInserted parameter update transformation for split node {user} arguments to {user.args}" + ) def detect_sharding_from_factory_config( @@ -260,6 +517,7 @@ def detect_sharding_from_factory_config( # 4. 
the allowed values are: # - "colwise" # - "rowwise" + # - "mamba" # - "sequence_parallel" # - "local_colwise" # - "local_rowwise" @@ -314,8 +572,8 @@ def detect_sharding_from_factory_config( # we have a match. Get the config for this layer config = tp_plan[key] if config == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -324,9 +582,10 @@ def detect_sharding_from_factory_config( min_local_shape=min_local_shape, ) ) + num_row_col_shards += 1 elif config == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.ROW, rank=rank, @@ -336,6 +595,19 @@ def detect_sharding_from_factory_config( ) ) num_row_col_shards += 1 + elif config == "mamba": + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + layer_type=LayerType.MAMBA, + ) + ) + num_row_col_shards += 1 elif "sequence" in config: # TODO: Sequence parallelism is not supported yet. ad_logger.warning("Sequence parallelism is not supported yet. Skipping.") @@ -345,8 +617,8 @@ def detect_sharding_from_factory_config( if "shared" in module_name: col_row_action = config.replace("local_", "") if col_row_action == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo( target_node=lin_node.name, split_dim=SplitDimension.COLUMN, rank=rank, @@ -356,8 +628,8 @@ def detect_sharding_from_factory_config( ) ) elif col_row_action == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo( target_node=lin_node.name, split_dim=SplitDimension.ROW, rank=rank, @@ -375,8 +647,8 @@ def detect_sharding_from_factory_config( elif "gather" in config: # Simple shard (row + all_gather) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -390,8 +662,8 @@ def detect_sharding_from_factory_config( ad_logger.warning( f"Unsupported sharding action {config}. Fallback to simple shard" ) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -408,7 +680,7 @@ def detect_sharding_from_factory_config( f"row-col pattern: {num_row_col_shards})" ) - num_matches = len(sharding_config.tp_transforms) + num_matches = len(sharding_config.weight_sharding_transforms) if sharding_config.support_partial_config: ad_logger.info( @@ -441,6 +713,42 @@ def detect_sharding_from_factory_config( ) +def detect_ssm_shard( + gm: GraphModule, + sharding_config: ShardingConfig, +) -> TransformInfo: + """A transformation to apply sharding to the model following SSM parallelism. + TODO: This is a TEMPORARY place for this logic due to the incompatibility between the + identify_regions_between_residuals() and subgraph() methods to detect layers. 
+ The goal is to have a unified single pass over the graph to detect layers and apply + appropriate sharding transformations. + """ + rank, world_size = sharding_config.rank, sharding_config.world_size + if world_size < 2: + ad_logger.info("Skipping TP sharding for single device") + return TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) + ad_logger.info("Running SSM sharding detection") + + # find all ssm nodes in the graph + ssm_nodes = filtered_nodes(gm.graph.nodes, ops=torch.ops.auto_deploy.torch_ssm) + # ssm_nodes = list(ssm_nodes)[1:2] + num_ssm_shards = 0 + for ssm_node in ssm_nodes: + # We assume that one ssm node defines a subgraph corresponding + # to a single Mamba layer. + # Find defining previous (in_proj) and next (out_proj) linear nodes. + in_proj_node, _ = bfs(ssm_node, is_linear_op, attr_next="args", include_root=False) + + num_ssm_shards += int( + _process_ssm_sharding(gm, in_proj_node, sharding_config, rank, world_size) + ) + + ad_logger.info(f"Found {num_ssm_shards} SSM shards") + return TransformInfo( + skipped=False, num_matches=num_ssm_shards, is_clean=False, has_valid_shapes=False + ) + + def detect_column_row_shard( gm: GraphModule, sharding_config: ShardingConfig, @@ -545,15 +853,17 @@ def detect_column_row_shard( if sharding_config.simple_shard_only: ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # simple shard when we have != 2 groups of linear nodes if len(nodes_linear) != 2: ad_logger.debug(f"Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # let's look at the unnacounted nodes. They are okay as long as they fall before the @@ -583,8 +893,9 @@ def detect_column_row_shard( # check if any unaccounted nodes are left. If so, do a simply shard if unaccounted_nodes or attention_related_nodes: ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # If we can account for all sharded nodes, we can do a two-way shard @@ -596,30 +907,49 @@ def detect_column_row_shard( # Column-row shard boundary region detection is probably wrong - there should be # only one attention operation. Fall back to simple shard. ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # Extract head dimension. We cannot shard below the head_dim size. 
            # Assume that head_dim is the last (innermost) dimension of the tensor
            min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
        else:
            min_local_shape = 1
-        for i, group in enumerate(nodes_linear.values()):
-            for n in group:
-                if i > 0:
-                    dist_op = "all_reduce"
-                else:
-                    dist_op = None
-                sharding_config.tp_transforms.append(
-                    TPShardingInfo.from_node(
-                        n,
-                        split_dim=i,
-                        rank=rank,
-                        world_size=world_size,
-                        dist_op=dist_op,
-                        min_local_shape=min_local_shape,
-                    )
-                )
+
+        # We insert a column-row shard for each group of linear nodes.
+        # This may require parameter updates of nodes whose args depend on (sharded)
+        # dimensions, such as view or split nodes.
+        nodes_to_column_shard = list(nodes_linear.values())[0]
+        nodes_to_row_shard = list(nodes_linear.values())[1]
+        if len(nodes_to_row_shard) != 1:
+            ad_logger.warning(
+                "Expecting only one linear node for row sharding, but got %s",
+                len(nodes_to_row_shard),
+            )
+            num_simple_shards += _process_simple_shard(
+                nodes_linear, rank, world_size, sharding_config
+            )
+            continue
+
+        # column-row sharding: column-shard the first group; the corresponding weight
+        # sharding and parameter update transforms are added to sharding_config directly.
+        _process_column_sharding(
+            gm, nodes_to_column_shard, sharding_config, rank, world_size, min_local_shape
+        )
+
+        # shard single row node
+        sharding_config.weight_sharding_transforms.append(
+            WeightShardingInfo.from_node(
+                nodes_to_row_shard[0],
+                split_dim=SplitDimension.ROW,
+                rank=rank,
+                world_size=world_size,
+                dist_op="all_reduce",
+            )
+        )
+        num_row_col_shards += 1
 
    ad_logger.info(
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index bc454e69396..c7e2b637f8f 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -106,10 +106,10 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
     return input_params, weight_params, output_params
 
 
-def extract_weight_node(mm_node: Node) -> int:
-    """Extracts the weight node from the given linear or BMM node.
We assume torch.bmm(activation, weight)""" +def extract_weight_node(node: Node) -> int: + """Extracts the weight node from the given parametrized node""" - def find_get_attr_node(node: Node) -> Node: + def find_get_attr_node(weight_node: Node) -> Node: """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op.""" # If node is a get_attr node return node # List of nodes allowed in between a get_attr node and the matmul node @@ -118,40 +118,47 @@ def find_get_attr_node(node: Node) -> Node: torch.ops.aten.view.default, } - if node.op == "get_attr": - return node + if weight_node.op == "get_attr": + return weight_node # If node is not in the list of allowable ops then return None - if node.target not in allowed_ops: + if weight_node.target not in allowed_ops: return None - for input_node in node.all_input_nodes: + for input_node in weight_node.all_input_nodes: result = find_get_attr_node(input_node) if result: return result return None - weight_node = mm_node.args[1] + if is_op(node, torch.ops.aten.bmm): + weight_node = node.args[1] + # for other parametrized nodes, we need to find the weight node + else: + weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] + # can be two weights (if bias weight is present) + assert len(weight_nodes) >= 1, "Expected exactly one weight node in the parametrized node" + weight_node = weight_nodes[0] # for modelopt quantized graph, there will be a quantize_op - _, weight_params, _ = get_quantization_params_from_linear_node(mm_node) + _, weight_params, _ = get_quantization_params_from_linear_node(node) weight_node = weight_params.input_node if weight_params else weight_node return find_get_attr_node(weight_node) -def num_users_of_weight_node(mm_node: Node) -> int: - """Returns the number of users of the weight node of the given matmul node.""" - weight_node = extract_weight_node(mm_node) +def num_users_of_weight_node(node: Node) -> int: + """Returns the number of users of the weight node of the given parametrized node.""" + weight_node = extract_weight_node(node) return len(weight_node.users) if weight_node is not None else 0 -def extract_param_names_from_lin_node(mm_node: Node) -> Tuple[str, Optional[str]]: - """Extracts the name of the parameter associated with the given matmul node. +def extract_param_names_from_node(node: Node) -> Tuple[str, Optional[str]]: + """Extracts the name of the parameter associated with the given parametrized node. Args: - mm_node: Matmul node in the graph. + node: node with weight parameters in the graph. """ - weight_node = extract_weight_node(mm_node) + weight_node = extract_weight_node(node) assert weight_node, "Cannot identify weight parameter of linear node." 
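Aside (editor illustration, not part of this patch): the rename above generalizes the helper from "linear node" to any parametrized node, where the weight and optional bias are simply the get_attr inputs of the node, while torch.bmm keeps its weight at args[1]. Below is a minimal, self-contained FX sketch of that access pattern; the toy module and parameter names are made up, and it assumes nothing beyond stock PyTorch.

import torch
import torch.fx as fx
import torch.nn as nn


class _ToyRoot(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(8, 4))
        self.bias = nn.Parameter(torch.randn(8))


# Build an aten-level graph by hand: x -> aten.linear(x, weight, bias)
graph = fx.Graph()
x = graph.placeholder("x")
w = graph.get_attr("weight")
b = graph.get_attr("bias")
out = graph.call_function(torch.ops.aten.linear.default, (x, w, b))
graph.output(out)
gm = fx.GraphModule(_ToyRoot(), graph)

lin = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.linear.default)
# Weight/bias are the get_attr inputs of the parametrized node.
attr_args = [a for a in lin.args if isinstance(a, fx.Node) and a.op == "get_attr"]
weight_name = attr_args[0].target
bias_name = attr_args[1].target if len(attr_args) > 1 else None
print(weight_name, bias_name)           # -> weight bias
print(gm(torch.randn(2, 4)).shape)      # -> torch.Size([2, 8])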
@@ -159,7 +166,14 @@ def extract_param_names_from_lin_node(mm_node: Node) -> Tuple[str, Optional[str] weight_name = weight_node.target # check for bias - bias_node = mm_node.args[2] if len(mm_node.args) > 2 else None + if is_op(node, torch.ops.aten.bmm): + bias_node = node.args[2] if len(node.args) > 2 else None + else: + weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] + if len(weight_nodes) > 1: + bias_node = weight_nodes[1] + else: + bias_node = None assert bias_node is None or bias_node.op == "get_attr" bias_name = bias_node.target if bias_node is not None else None @@ -365,23 +379,43 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]: return boundary_nodes +def identify_layer_subgraphs(gm: GraphModule) -> None: + pass + + def bfs( - node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None -) -> Node: - queue = [node] + node: Node, + target: Callable, + attr_next: str = "users", + boundary: Optional[Node] = None, + include_root: bool = True, +) -> Tuple[Node, int]: + """ + Breadth-first search of the graph. + Returns the found node and the depth of the node. + """ + depth = 0 + queue_at_depth = [node] + queue_at_depth_next = [] visited = set() - while queue: - cur_node = queue.pop(0) + while queue_at_depth or queue_at_depth_next: + cur_node = queue_at_depth.pop(0) if boundary is not None and cur_node == boundary: continue # Skip the boundary node. - if target(cur_node): - return cur_node - for next_node in getattr(cur_node, attr_next): - if boundary is not None and next_node == boundary: - continue # Do not expand past the boundary. - if next_node not in visited: - visited.add(next_node) - queue.append(next_node) + if target(cur_node) and (include_root or depth > 0): + return cur_node, depth + if hasattr(cur_node, attr_next): + for next_node in getattr(cur_node, attr_next): + if boundary is not None and next_node == boundary: + continue # Do not expand past the boundary. + if next_node not in visited: + visited.add(next_node) + queue_at_depth_next.append(next_node) + if not queue_at_depth: + queue_at_depth = queue_at_depth_next + queue_at_depth_next = [] + depth += 1 + raise RuntimeError(f"Could not find node with target condition {target}.") @@ -456,19 +490,19 @@ def predecessors( If exclude is provided, exclude nodes that satisfy the condition. """ preds = [] + seen = set() for arg in node.args: if isinstance(arg, Node): + if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)): + if arg not in seen: + preds.append(arg) + seen.add(arg) if depth > 1: - preds.extend(predecessors(arg, depth - 1, include, exclude)) - # add node arg if either: - # a) include and exclude are not specified - # b) include is specified and arg satisfies include condition - # c) exclude is specified and arg does not satisfy exclude condition - if exclude and exclude(arg): - continue - if (not include) or (include and include(arg)): - preds.append(arg) - return list(reversed(preds)) + for p in predecessors(arg, depth - 1, include, exclude): + if p not in seen: + preds.append(p) + seen.add(p) + return preds def successors( @@ -483,12 +517,83 @@ def successors( If exclude is provided, exclude nodes that satisfy the condition. 
""" succs = [] + seen = set() for user in node.users: + if ((not include) or (include and include(user))) and (not exclude or not exclude(user)): + if user not in seen: + succs.append(user) + seen.add(user) if depth > 1: - succs.extend(successors(user, depth - 1, include, exclude)) - # analogous logic to predecessors - if exclude and exclude(user): + for s in successors(user, depth - 1, include, exclude): + if s not in seen: + succs.append(s) + seen.add(s) + return succs + + +def subgraph( + sources: list[Node], + sinks: list[Node], + include_boundary_nodes: bool = True, + include: Optional[Callable[[Node], bool]] = None, + exclude: Optional[Callable[[Node], bool]] = None, +) -> List[Node]: + """ + Returns a list of nodes in a subgraph in computation DAG defined as all nodes + succeeding any of the node in sources and preceding any of the nodes in sinks. + It is built by a BFS traversal from sinks, where the sources list acts as a + boundary. We do it in this order (and not from sources to sinks) to include + nodes like weights or other inputs (they are not successors of sinks, so otherwise + they wouldn't be included). + + Optionally, include or exclude conditions may be specified to include [exclude] + only nodes that meet [don't meet] certain condition. + """ + subgraph_nodes = [] + seen = set() + queue = list(sinks) + sources_set = set(sources) + + # Initialize queue with sinks and mark them as seen + for node in sinks: + if node not in seen: + seen.add(node) + + # BFS traversal from sinks backwards through predecessors + while queue: + node = queue.pop(0) + + # Check if node should be included based on filters + should_include = True + if include is not None and not include(node): + should_include = False + if exclude is not None and exclude(node): + should_include = False + if not include_boundary_nodes and (node in sources_set) or (node in sinks): + should_include = False + + if should_include: + subgraph_nodes.append(node) + + # Stop traversal at source nodes (boundary) - don't explore their predecessors + if node in sources_set: continue - if (not include) or (include and include(user)): - succs.append(user) - return list(reversed(succs)) + + # Traverse to predecessor nodes (all inputs to this node) + for arg in node.args: + if isinstance(arg, Node) and arg not in seen: + seen.add(arg) + queue.append(arg) + + return subgraph_nodes + + +def draw_graph(gm: GraphModule, filename: str): + """ + Dump graphmodule to SVG file using PyTorch's built-in drawer. 
+ """ + from torch.fx.passes.graph_drawer import FxGraphDrawer + + drawer = FxGraphDrawer(gm, filename) + with open(f"{filename}.svg", "wb") as f: + f.write(drawer.get_dot_graph().create_svg()) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index 90e6b380338..aee98c37713 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -8,7 +8,7 @@ from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX from .logger import ad_logger from .node_utils import ( - extract_param_names_from_lin_node, + extract_param_names_from_node, get_quantization_params_from_linear_node, is_bmm_op, is_linear_op, @@ -117,7 +117,7 @@ def should_skip_quantization( else: if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)): return True - param_name, _ = extract_param_names_from_lin_node(node_or_name) + param_name, _ = extract_param_names_from_node(node_or_name) modname, _, _ = param_name.rpartition(".") return any(fnmatch(modname, pattern) for pattern in excluded_patterns) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index a95045b7d28..1ee817b0d30 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -2,10 +2,11 @@ import math import operator +import re from abc import ABC, abstractmethod -from enum import IntEnum +from enum import Enum, IntEnum from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -14,7 +15,14 @@ from ..models.factory import ShardingConfigSource from ..utils.logger import ad_logger -from .node_utils import extract_param_names_from_lin_node, is_op, num_users_of_weight_node +from .node_utils import ( + bfs, + extract_param_names_from_node, + is_linear_op, + is_op, + num_users_of_weight_node, + subgraph, +) from .quantization_utils import ( cutlass_fp4_scale_to_modelopt_fp4_scale, modelopt_fp4_scale_to_cutlass_fp4_scale, @@ -39,7 +47,9 @@ def _load_hook( if key not in state_dict: return p_to_load = state_dict[key] + p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load) + state_dict[key] = p_to_load @@ -54,43 +64,93 @@ def _load_hook_remove( state_dict.pop(key, None) -def _update_view_nodes(node: Node) -> None: +def _validate_sharded_shapes( + node: Node, fused_weight_dims: Optional[list] = None, world_size: int = None +) -> None: """ - After sharding weights of the linear node, using column split + Update the shapes of the view nodes and the split node parameters to account for the TP sharding. + 1. After sharding weights of the linear node using column split in attention module (Q, K, V), - the output Y = X @ W^T is [batch, seq, num_heads // TP_size, head_dim] - Some models hardcode the shape of the output to be [batch, seq, num_heads, head_dim] + the output Y = X @ W^T shape is [batch, seq, num_heads // TP_size, head_dim]. + Some models hardcode the shape of the output to [batch, seq, num_heads, head_dim] instead of implicit [batch, seq, -1, head_dim]. Detect such cases and update the shape of the view node accordingly. + 2. If the weights are fused (e.g,. QKV, gate_up, SSM, etc.), the follow-up split node parameters + need to be updated to account for the TP sharding. 
""" - view_nodes = [n for n in node.users if is_op(n, torch.ops.aten.view)] - for view_node in view_nodes: - view_shape = view_node.args[1] - if len(view_shape) == 4 and view_shape[2] != -1: + + # get the subgraph of this module. Subgraph boundary is the next linear node. + next_lin_node, depth = bfs(node, is_linear_op, include_root=False) + nodes_to_validate = subgraph( + [node], + [next_lin_node], + include=lambda n: is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]), + ) + for view_node in nodes_to_validate: + if len(view_node.args) < 2: + continue + if "sharded" in view_node.meta and view_node.meta["sharded"]: + continue + view_shape = list(view_node.args[1]) + if not isinstance(view_shape, list): + continue + if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: args = list(view_node.args) - args[1] = [view_shape[0], view_shape[1], -1, view_shape[3]] + view_shape[2] = view_shape[2] // world_size + args[1] = tuple(view_shape) view_node.args = tuple(args) + view_node.meta["sharded"] = True ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") + # if fused_weight_dims is provided, we need to update all split sizes + if fused_weight_dims is not None: + assert world_size is not None, "World size is required to update the split node params" + assert len(node.users) == 1, "Fused linear node should have only one user: a split node" + # find all split nodes in the region between this linear node and the next + split_nodes = subgraph( + [node], + [next_lin_node], + include=lambda n: is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]), + ) + for split_node in split_nodes: + orig_sizes = split_node.args[1] + new_sizes = [orig_sizes[i] // world_size for i in range(len(orig_sizes))] + args = list(split_node.args) + args[1] = new_sizes + split_node.args = tuple(args) + ad_logger.debug(f"\nUpdated split node {split_node} arguments to {split_node.args}") + -def _insert_sharded_matmul( +def shard_weight_tensor( gm: GraphModule, - node: Node, + weight_tensor: torch.Tensor, + param_key: str, dim: int, rank: int, world_size: int, - add_dist: bool = False, min_local_shape: int = 1, - quantization_cb: Optional[ - Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] - ] = None, -) -> None: - """Replace the matmul node with a new matmul node that accepts sharded weights. - - The state_dict is also updated to contain the sharded weights. + fused_weight_dims: Optional[list] = None, + requires_grad: bool = False, + update_param: bool = True, +) -> Tuple[torch.Tensor, torch.Size]: + """Shard a weight tensor across ranks and register load hook. + + Args: + gm: GraphModule containing the weight + weight_tensor: The weight tensor to shard + param_key: Parameter key for registering load hook + dim: Dimension to shard along + rank: Current rank + world_size: Total number of ranks + min_local_shape: Minimum local shape constraint (for GQA) + fused_weight_dims: List of dimensions for fused weights + custom_shard_fn: Optional custom function to shard the tensor + requires_grad: Whether the parameter should require gradients + update_param: Whether to update the parameter in the module + + Returns: + Tuple of (sharded_tensor, sharded_shape) """ - assert dim in [0, 1], "Only dim 0 and 1 are supported for sharding" - assert add_dist or dim == 0, "For dim=1 sharding, dist_op is required." 
def split_tensor( t: torch.Tensor, @@ -110,6 +170,251 @@ def split_tensor( return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups] return torch.tensor_split(t, ws, dim=d)[r] + # Handle fused weights + if fused_weight_dims is not None: + + def split_fused_tensor( + t: torch.Tensor, + fused_dims: list = fused_weight_dims, + d: int = dim, + ) -> torch.Tensor: + # dim_d = t.shape[d] + # num_parts = 1 + # part_size = dim_d // num_parts + # fused_dims = [part_size] * num_parts + return torch.cat( + [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)], + dim=d, + ) + + f_split = split_fused_tensor + else: + f_split = split_tensor + + sharded_weight = f_split(weight_tensor) + sharded_shape = sharded_weight.shape + + # Register load hook + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=f_split, + param_key=param_key, + param_shape=sharded_shape, + ) + ) + + # Update the parameter in the module + if update_param: + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) + param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=requires_grad) + setattr(submod, param_name, param_new) + + return sharded_weight, sharded_shape + + +def get_all_weights_in_subgraph( + sources: list[Node], + sinks: list[Node], +): + """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks.""" + weight_nodes = subgraph( + sources, sinks, include_boundary_nodes=False, include=lambda n: n.op == "get_attr" + ) + return weight_nodes + + +def _insert_sharded_mamba( + gm: GraphModule, + entry_node: Node, + dim: int, + rank: int, + world_size: int, + add_dist: bool = False, + min_local_shape: int = 1, + weights_to_shard: Optional[list[str]] = None, + weight_shard_dims: Optional[Dict[str, int]] = None, + fused_weight_dims: Optional[Dict[str, list]] = None, + quantization_cb: Optional[ + Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] + ] = None, +) -> bool: + """ + To shard Mamba layer, first column-shard the first linear layer: entry_node, + then shard all remaining weight tensors found in the subgraph defined between + entry_node and the next successor linear node. + First, validate if this is indeed a mamba module: within the subgraph, + there should be an torch_ssm node and conv1d node. 
+ + Args: + gm: GraphModule + entry_node: The first linear node of the Mamba layer + dim: Default shard dimension + rank: Current rank + world_size: Total number of ranks + add_dist: Whether to add distribution op after entry_node + min_local_shape: Minimum local shape constraint + weights_to_shard: Optional list of regex patterns to match weight names + weight_shard_dims: Optional dict mapping weight keys to their shard dimensions + fused_weight_dims: Optional dict mapping weight keys to their fused dimension lists + quantization_cb: Optional quantization callback + """ + # Find next linear node to define subgraph boundary + try: + next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False) + except RuntimeError: + ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") + return False + + # Get subgraph between entry_node and next linear node + subgraph_nodes = subgraph([entry_node], [next_lin_node]) + + ############################################################## + ########## validate if this is a valid Mamba module ########## + ############################################################## + # has_ssm = any(is_op(n, torch.ops.auto_deploy.mamba.torch_ssm_transform) for n in subgraph_nodes) + has_ssm = True + conv1d_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) + ] + if len(conv1d_nodes) != 1 or not has_ssm: + ad_logger.warning( + f"Subgraph does not contain exactly one conv1d node and torch_ssm_transform. " + f"Skipping Mamba sharding. conv1d_nodes={conv1d_nodes}, has_ssm={has_ssm}" + ) + return False + + ############################################################## + ########## infer split sizes for in_proj and conv1d ########## + ############################################################## + # in_proj and conv1d are most likely fused, followed up by split nodes. Infer split sizes: + if fused_weight_dims is None: + split_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) + ] + if len(split_nodes) != 2: + ad_logger.warning( + f"Subgraph does not contain exactly two split nodes. " + f"Skipping Mamba sharding. split_nodes={split_nodes}" + ) + return False + split_sizes_1 = split_nodes[0].args[1] + split_sizes_2 = split_nodes[1].args[1] + if split_sizes_1[1] != sum(split_sizes_2): + ad_logger.warning( + f"Split nodes have different sizes. " + f"Skipping Mamba sharding. split_sizes_1={split_sizes_1}, split_sizes_2={split_sizes_2}" + ) + return False + fused_weight_dims = { + "in_proj": split_sizes_1[0:1] + split_sizes_2 + split_sizes_1[2:], + "conv1d": split_sizes_2, + } + + conv1d_node = conv1d_nodes[0] + # conv1d_node last argument is the number of output channels. 
+ # This one is also sharded, so we need to update this parameter + conv_args = list(conv1d_node.args) + conv_args[-1] = conv1d_node.args[-1] // world_size + conv1d_node.args = tuple(conv_args) + + # First, shard the entry_node (the first linear layer) + # Extract entry node's fused_weight_dims by matching weight name against patterns + entry_fused_dims = None + if fused_weight_dims: + entry_weight_key, _ = extract_param_names_from_node(entry_node) + for pattern, dims in fused_weight_dims.items(): + if re.search(pattern, entry_weight_key): + entry_fused_dims = dims + break + + _shard_parameter_node( + gm=gm, + node=entry_node, + dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + add_dist=False, + min_local_shape=min_local_shape, + fused_weight_dims=entry_fused_dims, + ) + + # Get all weight nodes in the subgraph except for out_proj + weight_nodes = [ + n + for n in get_all_weights_in_subgraph([entry_node], [next_lin_node]) + if "out_proj" not in str(n) + ] + + # Shard remaining weights, such as conv1d or RMSNorm + for weight_node in weight_nodes: + weight_key = weight_node.target + + # Filter by regex patterns if provided + if weights_to_shard is not None: + if not any(pattern in weight_key for pattern in weights_to_shard): + continue + + # Determine shard dimension for this weight + shard_dim = weight_shard_dims.get(weight_key, dim) if weight_shard_dims else dim + + # Get the weight parameter + try: + weight_param = gm.get_parameter(weight_key) + except AttributeError: + ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") + continue + + # Get fused dims for this weight if specified + fused_dims = None + for k, v in fused_weight_dims.items(): + if k in weight_key: + fused_dims = v + break + + # Shard the weight tensor (also updates the parameter in the module) + _, sharded_shape = shard_weight_tensor( + gm=gm, + weight_tensor=weight_param, + param_key=weight_key, + dim=shard_dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_dims, + ) + + ad_logger.debug( + f"Sharded weight {weight_key} on dim {shard_dim}: " + f"{weight_param.shape} -> {sharded_shape}" + ) + + +def _shard_parameter_node( + gm: GraphModule, + node: Node, + dim: int, + rank: int, + world_size: int, + add_dist: bool = False, + min_local_shape: int = 1, + fused_weight_dims: Optional[list] = None, + quantization_cb: Optional[ + Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] + ] = None, +) -> None: + """Replace the node with parametrized weight tensor with a new node that accepts sharded weights. + + The state_dict is also updated to contain the sharded weights. + """ + assert dim in [0, 1], "Only dim 0 and 1 are supported for sharding" + assert add_dist or dim == 0, "For dim=1 sharding, dist_op is required." 
+ num_users = num_users_of_weight_node(node) if num_users > 1 or num_users == 0: ad_logger.warning( @@ -117,41 +422,36 @@ def split_tensor( ) return # get weight and bias key - weight_key, bias_key = extract_param_names_from_lin_node(node) + weight_key, bias_key = extract_param_names_from_node(node) modname = weight_key.rpartition(".")[0] submod = gm.get_submodule(modname) - def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> torch.Size: - # split or remove it - param_new = ( - None - if remove - else nn.Parameter( - split_tensor(gm.get_parameter(param_key)).detach().clone(), requires_grad=False - ) - ) - - # update the parameter - param_name = param_key.rpartition(".")[-1] - setattr(submod, param_name, param_new) - return torch.Size() if param_new is None else param_new.shape - - # update weight - weight_new_shape = set_new_param(submod, weight_key) - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, f_split=split_tensor, param_key=weight_key, param_shape=weight_new_shape - ) + # Shard weight using the unified function (also updates the parameter) + original_weight = gm.get_parameter(weight_key) + _, weight_new_shape = shard_weight_tensor( + gm=gm, + weight_tensor=original_weight, + param_key=weight_key, + dim=dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims, ) if bias_key is not None and dim == 0: # update bias for dim 0 --> we can handle it like the weight - bias_new_shape = set_new_param(submod, bias_key) - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, f_split=split_tensor, param_key=bias_key, param_shape=bias_new_shape - ) + original_bias = gm.get_parameter(bias_key) + shard_weight_tensor( + gm=gm, + weight_tensor=original_bias, + param_key=bias_key, + dim=dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims, ) elif bias_key is not None and rank != world_size - 1: # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid @@ -161,7 +461,8 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to args[2] = None node.args = tuple(args) gm.graph.erase_node(node_bias) - set_new_param(submod, bias_key, remove=True) + bias_param_name = bias_key.rpartition(".")[-1] + setattr(submod, bias_param_name, None) gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key)) if quantization_cb is not None: @@ -176,9 +477,12 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to world_size=world_size, ) - # column shard with no gather: the output is sharded + # # # column shard with no gather: the output is sharded if not add_dist: - _update_view_nodes(node) + if is_linear_op(node): + _validate_sharded_shapes( + node, fused_weight_dims=fused_weight_dims, world_size=world_size + ) return # figure out the right dist op @@ -195,6 +499,17 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to dist_node.replace_input_with(dist_node, node) +def _update_node_args(node: Node, args: tuple) -> None: + """Update the node's arguments with the new sharded arguments.""" + if "sharded" in node.meta and node.meta["sharded"]: + return + node.args = args + node.meta["sharded"] = True + ad_logger.debug( + f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." 
+    )
+
+
 class SplitDimension(IntEnum):
     """Enum for tensor split dimensions in sharding."""

@@ -242,15 +557,25 @@ def check_and_apply(self, gm: GraphModule, node: Node) -> bool:
         return True


-class TPShardingInfo(ShardingTransformInfo):
+class LayerType(Enum):
+    ATTENTION = "attention"
+    MAMBA = "mamba"
+    MLP = "mlp"
+    MOE = "moe"
+
+
+class WeightShardingInfo(ShardingTransformInfo):
     """Configuration for TP sharding transformations."""

     split_dim: SplitDimension
     dist_op: Optional[Literal["all_reduce", "all_gather"]] = None
     min_local_shape: int = 1
+    layer_type: LayerType = LayerType.MLP
+    # used for TP sharding of fused weights
+    fused_weight_dims: Optional[list] = None

     @classmethod
-    def from_node(cls, node: Node, **kwargs) -> "TPShardingInfo":
+    def from_node(cls, node: Node, **kwargs) -> "WeightShardingInfo":
        """
        Create the correct TPShardingInfo subclass (FP8/FP4/base) based on `node`.
        """
@@ -276,16 +601,47 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:

     def apply(self, gm: GraphModule, node: Node) -> None:
         """Apply TP sharding transformation to the graph module."""
+        if self.layer_type == LayerType.MAMBA:
+            _insert_sharded_mamba(
+                gm=gm,
+                entry_node=node,
+                dim=self.split_dim.value,
+                rank=self.rank,
+                world_size=self.world_size,
+                add_dist=self.dist_op is not None,
+                min_local_shape=self.min_local_shape,
+                fused_weight_dims=self.fused_weight_dims
+                if isinstance(self.fused_weight_dims, dict)
+                else None,
+            )
+        else:
+            _shard_parameter_node(
+                gm=gm,
+                node=node,
+                dim=self.split_dim.value,
+                rank=self.rank,
+                world_size=self.world_size,
+                add_dist=self.dist_op is not None,
+                min_local_shape=self.min_local_shape,
+                fused_weight_dims=self.fused_weight_dims,
+            )

-        _insert_sharded_matmul(
-            gm=gm,
-            node=node,
-            dim=self.split_dim.value,
-            rank=self.rank,
-            world_size=self.world_size,
-            add_dist=self.dist_op is not None,
-            min_local_shape=self.min_local_shape,
-        )
+
+class ParameterUpdateInfo(ShardingTransformInfo):
+    """Configuration for node args sharding transformations."""
+
+    target_node: str
+    rank: int
+    world_size: int
+    args: tuple
+
+    def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
+        """Validate the transformation configuration."""
+        return len(node.args) == len(self.args)
+
+    def apply(self, gm: GraphModule, node: Node) -> None:
+        """Apply the transformation to the graph module."""
+        _update_node_args(node, self.args)


 class QuantizationShardingMixin(ABC):
@@ -345,14 +701,11 @@ def quantization_cb(
                 self.shard_load_hook,
                 weight_name=weight_key,
                 weight_shape=weight_new_shape,
-                dim=dim,
-                rank=rank,
-                world_size=world_size,
             )
         )


-class FP8TPShardingInfo(QuantizationShardingMixin, TPShardingInfo):
+class FP8TPShardingInfo(QuantizationShardingMixin, WeightShardingInfo):
     """Tensor-parallel sharding for FP8-quantized linears."""

     def scale_names(self) -> List[str]:
@@ -387,7 +740,7 @@ def shard_load_hook(
         return

     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_matmul(
+        _shard_parameter_node(
             gm=gm,
             node=node,
             dim=self.split_dim.value,
@@ -412,7 +765,7 @@ def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank,
     )


-class FP4TPShardingInfo(QuantizationShardingMixin, TPShardingInfo):
+class FP4TPShardingInfo(QuantizationShardingMixin, WeightShardingInfo):
     """Tensor-parallel sharding for FP4-quantized linears."""

     def scale_names(self) -> List[str]:
@@ -455,7 +808,7 @@ def shard_load_hook(
         )

     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_matmul(
+        _shard_parameter_node(
             gm=gm,
             node=node,
             dim=self.split_dim.value,
@@ -480,7 +833,7 @@ def _resolve_tp_cls_from_node(node: Node):
             return cls
         except Exception:
             pass
-    return TPShardingInfo
+    return WeightShardingInfo


 class BMMShardingInfo(ShardingTransformInfo):
@@ -536,28 +889,25 @@ def handle_tensor(
             end_idx: End index for sharding
         """
-        # Define slice function for the sharding
-        def slice_tensor(t: torch.Tensor) -> torch.Tensor:
-            return t[start_idx:end_idx]
-
         if tensor_node.op == "get_attr":
-            # Handle parameter tensor
+            # Handle parameter tensor using unified shard_weight_tensor
             weight_key = tensor_node.target
-            modname, _, param_name = weight_key.rpartition(".")
             param = gm.get_parameter(weight_key)

-            # Update the parameter with its shard
-            param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
-            gm.get_submodule(modname).register_parameter(param_name, param_new)
-
-            # Register load state dict hook
-            gm._register_load_state_dict_pre_hook(
-                partial(
-                    _load_hook,
-                    f_split=slice_tensor,
-                    param_key=weight_key,
-                    param_shape=param_new.shape,
-                )
+            # Define slice function for the sharding
+            def slice_tensor(t: torch.Tensor) -> torch.Tensor:
+                return t[start_idx:end_idx]
+
+            # Use shard_weight_tensor with custom shard function (also updates the parameter)
+            shard_weight_tensor(
+                gm=gm,
+                weight_tensor=param,
+                param_key=weight_key,
+                dim=0,  # BMM slices along batch dimension
+                rank=self.rank,
+                world_size=self.world_size,
+                custom_shard_fn=slice_tensor,
+                requires_grad=True,  # BMM parameters require gradients
             )
         else:
             # Handle dynamic tensor
@@ -834,6 +1184,22 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]:
     return EPShardingInfo


+class ShardingSource(Enum):
+    """Enum for sharding source."""
+
+    HEURISTIC = "heuristic"
+    FACTORY = "factory"
+
+
+class ShardingDim(Enum):
+    """Enum for sharding dimension."""
+
+    SSM = "ssm"
+    TP = "tp"
+    EP = "ep"
+    BMM = "bmm"
+
+
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""

@@ -842,13 +1208,25 @@ class ShardingConfig(BaseModel):
     world_size: int = Field(default=1)
     predefined_config: Optional[Dict[str, Any]] = None
     simple_shard_only: bool = Field(default=False)
-    use_sharding_from_factory: bool = False
     support_partial_config: bool = False
+    sharding_source: List[ShardingSource] = Field(
+        default_factory=lambda: [ShardingSource.HEURISTIC]
+    )
     sharding_dims: List[str] = Field(default_factory=list)
-    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
+    weight_sharding_transforms: List[WeightShardingInfo] = Field(default_factory=list)
+    parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
     ep_transforms: List[EPShardingInfo] = Field(default_factory=list)

+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._transform_list_dict = {
+            WeightShardingInfo: self.weight_sharding_transforms,
+            BMMShardingInfo: self.bmm_transforms,
+            EPShardingInfo: self.ep_transforms,
+            ParameterUpdateInfo: self.parameter_update_transforms,
+        }
+
     @model_validator(mode="after")
     def _validate_and_normalize(self):
         # Normalize empty dict to None for "no config"
@@ -859,6 +1237,18 @@ def _validate_and_normalize(self):
         self.validate_config()
         return self

+    def add(self, transform: ShardingTransformInfo) -> bool:
+        """Append a sharding transform only if its target node has not been
+        sharded before. Existing transforms are not overwritten.
+ """ + # try to add to appropriate transformation list + transform_list = self._transform_list_dict[type(transform)] + for existing_transform in transform_list: + if existing_transform.target_node == transform.target_node: + return False + transform_list.append(transform) + return True + def validate_config(self) -> bool: if self.factory_source != ShardingConfigSource.HUGGINGFACE: ad_logger.warning( @@ -909,27 +1299,3 @@ def validate_config(self) -> bool: def get_predefined_config(self) -> Dict[str, Any]: return self.predefined_config - - -def _append_simple_shard( - nodes_linear: Dict[Node, List[Node]], - rank: int, - world_size: int, - sharding_config: ShardingConfig, -) -> None: - # for every linear node: - # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] - for node_group in nodes_linear.values(): - for n in node_group: - tp_shards.append( - TPShardingInfo( - target_node=n.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, - ) - ) - sharding_config.tp_transforms.extend(tp_shards) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 76d48669d61..58855fb0318 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -15,7 +15,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ( SplitDimension, - TPShardingInfo, + WeightShardingInfo, ) from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op @@ -272,7 +272,7 @@ def _run_pattern_detection_job( dim = SplitDimension.COLUMN dist_op = None expected_transformations.append( - TPShardingInfo( + WeightShardingInfo( target_node=node.name, split_dim=dim, rank=rank, @@ -293,7 +293,7 @@ def _run_pattern_detection_job( dim = SplitDimension.ROW dist_op = "all_reduce" expected_transformations.append( - TPShardingInfo( + WeightShardingInfo( target_node=node.name, split_dim=dim, rank=rank, @@ -307,7 +307,7 @@ def _run_pattern_detection_job( for node in gm.graph.nodes: if is_linear_op(node): expected_transformations.append( - TPShardingInfo( + WeightShardingInfo( target_node=node.name, split_dim=SplitDimension.COLUMN, # Simple shard uses dim=0 rank=rank, @@ -351,7 +351,7 @@ def _run_pattern_detection_job( optimizer.shared_config.local_rank = rank optimizer.shared_config.world_size = world_size _ = optimizer(None, gm) - detected_transformations = optimizer.shared_config.sharding_config.tp_transforms + detected_transformations = optimizer.shared_config.sharding_config.weight_sharding_transforms print(f"detected_transformations: {detected_transformations}") print(f"expected_transformations: {expected_transformations}")