diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index ead86ade2f6..e32512a9f7f 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -73,8 +73,10 @@ transforms: detect_sharding: stage: sharding simple_shard_only: false - use_sharding_from_factory: false - support_partial_config: false + # sharding_source: ['factory', 'custom', 'heuristic'] + sharding_source: ['heuristic'] + support_partial_config: true + # custom_sharding_config: 'tp_sharding.yaml' sharding_dims: ['tp', 'ep', 'bmm'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 7cd235cefd4..ca6e3689a88 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -163,17 +163,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): "If False, auto-detect and use column+row (all_reduce) sharding when possible.", ) - use_sharding_from_factory: bool = Field( - default=False, - description="If True, use sharding from the model factory. 
If False, use sharding from the " - "AutoDeployConfig.", - ) - - sharding_dims: List[str] = Field( - default=["tp", "ep", "dp"], - description="The sharding methods to apply by the heuristic sharding stage.", - ) - compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = ( Field( default="torch-compile", diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 6915dac8540..d6f6b764ff3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -155,6 +155,16 @@ def from_last_info(cls, info: "TransformInfo") -> "TransformInfo": has_valid_shapes=info.has_valid_shapes, ) + # overload += operator to concatenate TransformInfo objects + def __iadd__(self, other: "TransformInfo") -> "TransformInfo": + # since TransformInfo is frozen, instead, we return a new TransformInfo object + return TransformInfo( + skipped=self.skipped & other.skipped, + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean & other.is_clean, + has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes, + ) + def __or__(self, other: "TransformInfo") -> "TransformInfo": """Merge two TransformInfo objects.""" return TransformInfo( diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index d6de54b22d2..e4e61f3418e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -39,6 +39,8 @@ BMMShardingInfo, EPShardingInfo, ShardingConfig, + ShardingDim, + ShardingSource, ShardingTransformInfo, SplitDimension, TPShardingInfo, @@ -106,30 +108,36 @@ def _append_simple_shard( ) -> None: # for every linear node: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] + num_shards = 0 for 
node_group in nodes_linear.values(): for n in node_group: - tp_shards.append( - TPShardingInfo.from_node( - n, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + n, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) - sharding_config.tp_transforms.extend(tp_shards) + return num_shards class ShardingTransformConfig(TransformConfig): """Configuration for sharding transformations.""" simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = Field(default=False) - support_partial_config: bool = Field(default=False) - # Which sharding families to run: any subset of {"tp", "ep", "bmm"} - sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"]) + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) + custom_sharding_config: str = Field(default="") + # Which sharding dimensions to run: any subset of {"tp", "ep", "bmm"} + sharding_dims: List[ShardingDim] = Field( + default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] + ) @TransformRegistry.register("detect_sharding") @@ -187,58 +195,53 @@ def _apply( ) shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only shared_config.sharding_config.support_partial_config = self.config.support_partial_config + shared_config.sharding_config.sharding_source = self.config.sharding_source + if isinstance(self.config.custom_sharding_config, str): + shared_config.sharding_config.read_custom_sharding_config( + self.config.custom_sharding_config + ) shared_config.sharding_config.sharding_dims = self.config.sharding_dims - shared_config.sharding_config.use_sharding_from_factory = ( - self.config.use_sharding_from_factory - ) - sharding_config = shared_config.sharding_config 
sharding_config.validate_config()
-        if (
-            shared_config.sharding_config.use_sharding_from_factory
-            and len(shared_config.sharding_config.get_predefined_config()) > 0
-        ):
-            ad_logger.info("Applying sharding from config")
-            factory_info = detect_sharding_from_factory_config(gm, sharding_config)
-            return gm, factory_info
+        info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
+        for source in shared_config.sharding_config.sharding_source:
+            if source == ShardingSource.FACTORY:
+                if len(shared_config.sharding_config.get_predefined_config()) == 0:
+                    ad_logger.warning(
+                        "No factory config found. Skipping sharding from factory config"
+                    )
+                    continue
+                ad_logger.info("Applying sharding from factory config")
+                info += detect_sharding_from_factory_config(gm, sharding_config)
-        ad_logger.info(
-            f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
-        )
-        # run TP sharding across ranks
-        if "tp" in shared_config.sharding_config.sharding_dims:
-            tp_info = detect_column_row_shard(gm, sharding_config)
-        else:
-            tp_info = TransformInfo(
-                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
-            )
+            elif source == ShardingSource.CUSTOM:
+                ad_logger.info("Applying sharding from custom config")
+                if sharding_config.custom_sharding_config is None:
+                    ad_logger.warning(
+                        "No custom sharding config found. 
Skipping sharding from custom config" + ) + continue + sharding_config.predefined_config = sharding_config.custom_sharding_config + info += detect_sharding_from_factory_config(gm, sharding_config) - # run EP sharding across ranks - if "ep" in shared_config.sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + elif source == ShardingSource.HEURISTIC: + ad_logger.info( + f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}" + ) + # run TP sharding across ranks + if ShardingDim.TP in sharding_config.sharding_dims: + info += detect_column_row_shard(gm, sharding_config) - # run BMM sharding across ranks - if "bmm" in shared_config.sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + # run EP sharding across ranks + if ShardingDim.EP in sharding_config.sharding_dims: + info += detect_ep_shard(gm, sharding_config) + + # run BMM sharding across ranks + if ShardingDim.BMM in sharding_config.sharding_dims: + info += detect_dp_bmm_shard(gm, sharding_config) - info = TransformInfo( - skipped=tp_info.skipped and ep_info.skipped and dp_bmm_info.skipped, - num_matches=tp_info.num_matches + ep_info.num_matches + dp_bmm_info.num_matches, - is_clean=tp_info.is_clean and ep_info.is_clean and dp_bmm_info.is_clean, - has_valid_shapes=tp_info.has_valid_shapes - and ep_info.has_valid_shapes - and dp_bmm_info.has_valid_shapes, - ) return gm, info @@ -288,7 +291,6 @@ def detect_sharding_from_factory_config( "o_proj", ] - num_shards = 0 num_simple_shards = 0 num_row_col_shards = 0 @@ -310,32 +312,34 @@ def detect_sharding_from_factory_config( pattern_string = pattern_string.replace("*", "@") pattern_regex = re.escape(pattern_string).replace("@", ".*") if re.match(pattern_regex, module_name): 
- num_shards += 1 # we have a match. Get the config for this layer config = tp_plan[key] if config == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op=None, - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + ) ) ) elif config == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.ROW, - rank=rank, - world_size=world_size, - dist_op="all_reduce", - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + min_local_shape=min_local_shape, + ) ) ) - num_row_col_shards += 1 elif "sequence" in config: # TODO: Sequence parallelism is not supported yet. ad_logger.warning("Sequence parallelism is not supported yet. 
Skipping.") @@ -345,28 +349,31 @@ def detect_sharding_from_factory_config( if "shared" in module_name: col_row_action = config.replace("local_", "") if col_row_action == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op=None, - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + ) ) ) elif col_row_action == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, - split_dim=SplitDimension.ROW, - rank=rank, - world_size=world_size, - dist_op="all_reduce", - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + min_local_shape=min_local_shape, + ) ) ) - num_row_col_shards += 1 else: ad_logger.warning(f"Unsupported sharding action {config}. Skipping.") else: @@ -375,64 +382,44 @@ def detect_sharding_from_factory_config( elif "gather" in config: # Simple shard (row + all_gather) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_simple_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) - num_simple_shards += 1 else: ad_logger.warning( f"Unsupported sharding action {config}. 
Fallback to simple shard" ) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - lin_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_simple_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) # after successful match, break the loop break ad_logger.info( - f"Applied {num_shards} TP shards (simple: {num_simple_shards}, " - f"row-col pattern: {num_row_col_shards})" + f"Applied {num_simple_shards + num_row_col_shards} TP shards (simple: {num_simple_shards}, " + f"row-col: {num_row_col_shards})" ) num_matches = len(sharding_config.tp_transforms) - if sharding_config.support_partial_config: - ad_logger.info( - f"Partial factory config applied only for TP. " - f"Applying heuristics for {sharding_config.sharding_dims}." - ) - - # run EP sharding across ranks - if "ep" in sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) - - # run BMM sharding across ranks - if "bmm" in sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) - num_matches += ep_info.num_matches + dp_bmm_info.num_matches - return TransformInfo( skipped=False, num_matches=num_matches, @@ -514,7 +501,6 @@ def detect_column_row_shard( # col_split (dim 1) 2nd group + all_reduce output of 2nd group # 3. 
Linear nodes that are not in two groups or we cannot account for all nodes: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) output - num_shards = 0 num_simple_shards = 0 num_row_col_shards = 0 for n_start, n_end in zip(boundary_nodes[:-1], boundary_nodes[1:]): @@ -541,19 +527,19 @@ def detect_column_row_shard( if len(nodes_linear) == 0: continue - num_shards += 1 - if sharding_config.simple_shard_only: ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # simple shard when we have != 2 groups of linear nodes if len(nodes_linear) != 2: ad_logger.debug(f"Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # let's look at the unnacounted nodes. They are okay as long as they fall before the @@ -583,8 +569,9 @@ def detect_column_row_shard( # check if any unaccounted nodes are left. If so, do a simply shard if unaccounted_nodes or attention_related_nodes: ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # If we can account for all sharded nodes, we can do a two-way shard @@ -596,8 +583,9 @@ def detect_column_row_shard( # Column-row shard boundary region detection is probably wrong - there should be # only one attention operation. Fall back to simple shard. 
ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += int( + _append_simple_shard(nodes_linear, rank, world_size, sharding_config) + ) continue # Extract head dimension. We cannot shard below the head_dim size. # Assume that head_dim is the last (innermost) dimension of the tensor @@ -610,18 +598,20 @@ def detect_column_row_shard( dist_op = "all_reduce" else: dist_op = None - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - n, - split_dim=i, - rank=rank, - world_size=world_size, - dist_op=dist_op, - min_local_shape=min_local_shape, + num_row_col_shards += int( + sharding_config.append_TP( + TPShardingInfo.from_node( + n, + split_dim=i, + rank=rank, + world_size=world_size, + dist_op=dist_op, + min_local_shape=min_local_shape, + ) ) ) - num_row_col_shards += 1 + num_shards = num_simple_shards + num_row_col_shards ad_logger.info( f"Found {num_shards} TP shards (simple: {num_simple_shards}, row-col: {num_row_col_shards})" ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index a95045b7d28..93f7a9fad78 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -1,14 +1,17 @@ """Sharding config definitions for the inference optimizer.""" +import json import math import operator from abc import ABC, abstractmethod -from enum import IntEnum +from enum import Enum, IntEnum from functools import partial +from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Sequence import torch import torch.nn as nn +import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator from torch.fx import GraphModule, Node @@ -834,6 +837,22 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]: return EPShardingInfo +class 
ShardingSource(Enum): + """Enum for sharding source.""" + + HEURISTIC = "heuristic" + FACTORY = "factory" + CUSTOM = "custom" + + +class ShardingDim(Enum): + """Enum for sharding dimension.""" + + TP = "tp" + EP = "ep" + BMM = "bmm" + + class ShardingConfig(BaseModel): """Configuration for sharding the model.""" @@ -841,9 +860,12 @@ class ShardingConfig(BaseModel): rank: int = Field(default=0) world_size: int = Field(default=1) predefined_config: Optional[Dict[str, Any]] = None + custom_sharding_config: Optional[Dict[str, Any]] = None simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = False support_partial_config: bool = False + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) sharding_dims: List[str] = Field(default_factory=list) tp_transforms: List[TPShardingInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) @@ -859,6 +881,41 @@ def _validate_and_normalize(self): self.validate_config() return self + def read_custom_sharding_config(self, config_path: str) -> bool: + """Read the custom sharding config from the given path. + + Supports both JSON and YAML file formats. The format is auto-detected + based on the file extension (.json, .yaml, .yml). 
+ """ + path = Path(config_path) + + if not path.exists(): + ad_logger.warning(f"Sharding config file not found: {config_path}") + return False + + try: + with open(config_path, "r") as f: + if path.suffix.lower() in [".yaml", ".yml"]: + self.custom_sharding_config = yaml.safe_load(f) + elif path.suffix.lower() == ".json": + self.custom_sharding_config = json.load(f) + else: + ad_logger.warning(f"Unsupported sharding config file format: {path.suffix}") + except Exception as e: + ad_logger.warning(f"Failed to read sharding config file: {e}") + return False + return True + + def append_TP(self, tp_transform: TPShardingInfo) -> bool: + """Append a TP transform only if that node was + not sharded before. Do not overwrite existing transforms. + """ + for existing_transform in self.tp_transforms: + if existing_transform.target_node == tp_transform.target_node: + return False + self.tp_transforms.append(tp_transform) + return True + def validate_config(self) -> bool: if self.factory_source != ShardingConfigSource.HUGGINGFACE: ad_logger.warning( diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index 13e8d4d0040..e554e2c572b 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -274,4 +274,6 @@ def run_sharding_pattern_detection_test( print("detected_set", detected_set) print("expected_set", expected_set) - assert detected_set == expected_set, "Expected sharding pattern does not match detected pattern" + assert detected_set == expected_set, ( + f"Expected sharding pattern does not match detected pattern: {detected_set} != {expected_set}" + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index 
add1b399229..784675757be 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -66,8 +66,9 @@ def _run_job( { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], "sharding_dims": ["bmm"], + "support_partial_config": False, }, "sharding_transform_executor": { "stage": "sharding", @@ -128,7 +129,8 @@ def _run_pattern_detection_job( { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], + "support_partial_config": False, }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 94e236cd4e4..c2cc7ed46bb 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -50,8 +50,9 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int: { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], "sharding_dims": ["ep"], + "support_partial_config": False, }, "sharding_transform_executor": { "stage": "sharding", @@ -118,7 +119,8 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": False, + "sharding_source": ["heuristic"], + "support_partial_config": False, }, }, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 76d48669d61..f868cd8de35 100644 --- 
a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -1,5 +1,7 @@ """Tests for basic graph sharding.""" +# add to the path directory 4 directories up +import os from functools import partial from typing import Type @@ -7,6 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import yaml from _dist_test_utils import get_device_counts from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm from _model_test_utils import FakeFP8Linear @@ -193,12 +196,22 @@ def verify_local_weight_sizes(gm) -> bool: op_expected = getattr(torch.ops.auto_deploy, dist_op_expected) gm = torch_export_to_gm(model, args=(x,), clone=True) + sharding_source = ["custom"] if from_config else ["heuristic"] + + if sharding_source == ["custom"]: + # If the file does not exist, write predefined_config to tp_sharding.yaml file + if not os.path.exists("tp_sharding.yaml"): + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) gm_transformed = InferenceOptimizer( None, { "detect_sharding": { "stage": "sharding", - "use_sharding_from_factory": from_config, + "sharding_source": sharding_source, + "custom_sharding_config": "tp_sharding.yaml", + "support_partial_config": False, + "sharding_dims": ["tp"], }, "sharding_transform_executor": { "stage": "sharding", @@ -338,23 +351,33 @@ def _run_pattern_detection_job( ) ) + sharding_source = ["custom"] if from_config else ["heuristic"] + + if sharding_source == ["custom"]: + # If the file does not exist, write predefined_config to tp_sharding.yaml file + if not os.path.exists("tp_sharding.yaml"): + with open("tp_sharding.yaml", "w") as f: + yaml.dump(predefined_config, f, sort_keys=False) + # get detected transformations optimizer = InferenceOptimizer( None, { "detect_sharding": { "stage": "sharding", - 
"use_sharding_from_factory": from_config, + "sharding_source": sharding_source, + "custom_sharding_config": "tp_sharding.yaml", + "support_partial_config": False, + "sharding_dims": ["tp"], }, }, ) optimizer.shared_config.local_rank = rank optimizer.shared_config.world_size = world_size + optimizer.shared_config.sharding_config.predefined_config = predefined_config _ = optimizer(None, gm) detected_transformations = optimizer.shared_config.sharding_config.tp_transforms - print(f"detected_transformations: {detected_transformations}") - print(f"expected_transformations: {expected_transformations}") # Run pattern detection test run_sharding_pattern_detection_test(detected_transformations, expected_transformations) @@ -409,7 +432,3 @@ def test_sharding_pattern_detection( No need to run distributed job, can be run on single process. """ _run_pattern_detection_job(model_cls, bias, 0, world_size, from_config) - - -if __name__ == "__main__": - _run_pattern_detection_job(nn.Linear, False, 0, 8, False)