6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -73,8 +73,10 @@ transforms:
detect_sharding:
stage: sharding
simple_shard_only: false
use_sharding_from_factory: false
support_partial_config: false
# sharding_source: ['factory', 'custom', 'heuristic']
sharding_source: ['heuristic']
support_partial_config: true
# custom_sharding_config: 'tp_sharding.yaml'
Review comment (Member): looks like this is just a leftover from testing and should be reverted?
sharding_dims: ['tp', 'ep', 'bmm']
requires_shape_prop: true
# TODO: (hg) need to ensure run_shape_prop after sharding.
11 changes: 0 additions & 11 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -163,17 +163,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"If False, auto-detect and use column+row (all_reduce) sharding when possible.",
)

use_sharding_from_factory: bool = Field(
default=False,
description="If True, use sharding from the model factory. If False, use sharding from the "
"AutoDeployConfig.",
)

sharding_dims: List[str] = Field(
default=["tp", "ep", "dp"],
description="The sharding methods to apply by the heuristic sharding stage.",
)

compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
Field(
default="torch-compile",
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/transform/interface.py
@@ -155,6 +155,16 @@ def from_last_info(cls, info: "TransformInfo") -> "TransformInfo":
has_valid_shapes=info.has_valid_shapes,
)

# overload += operator to concatenate TransformInfo objects
def __iadd__(self, other: "TransformInfo") -> "TransformInfo":
# since TransformInfo is frozen, instead, we return a new TransformInfo object
return TransformInfo(
skipped=self.skipped & other.skipped,
num_matches=self.num_matches + other.num_matches,
is_clean=self.is_clean & other.is_clean,
has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes,
)
Review comment (Member) on lines +159 to +166: Please use the existing __add__ operator instead. __iadd__ is by convention an in-place operator, i.e., it means that

config1 = TransformInfo()
config2 = TransformInfo() 
config3 = config1
config3 += config2
assert config3 is config1  # is operator checks for same object!

However, this assertion would fail since you actually create a new object
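
For illustration, a minimal standalone sketch of the suggestion above: expose the merge through __add__, so that a += b falls back to __add__ and rebinds the name to a new object, which matches the frozen-model semantics. The class below only mimics TransformInfo with the four fields visible in this diff; it is not the real implementation.

```python
from pydantic import BaseModel, ConfigDict


class TransformInfo(BaseModel):
    """Stand-in for the real TransformInfo; field set assumed from the diff above."""

    model_config = ConfigDict(frozen=True)

    skipped: bool = False
    num_matches: int = 0
    is_clean: bool = False
    has_valid_shapes: bool = False

    def __add__(self, other: "TransformInfo") -> "TransformInfo":
        # Returning a new object is the expected contract for __add__; since no
        # __iadd__ is defined, `a += b` falls back to `a = a + b` automatically.
        return TransformInfo(
            skipped=self.skipped & other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean & other.is_clean,
            has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes,
        )


info = TransformInfo(num_matches=2, is_clean=True)
info += TransformInfo(num_matches=3)  # rebinds `info`; the original object is untouched
assert info.num_matches == 5 and info.is_clean is False
```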


def __or__(self, other: "TransformInfo") -> "TransformInfo":
"""Merge two TransformInfo objects."""
return TransformInfo(
300 changes: 145 additions & 155 deletions tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Large diffs are not rendered by default.

61 changes: 59 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
@@ -1,14 +1,17 @@
"""Sharding config definitions for the inference optimizer."""

import json
import math
import operator
from abc import ABC, abstractmethod
from enum import IntEnum
from enum import Enum, IntEnum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence

import torch
import torch.nn as nn
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from torch.fx import GraphModule, Node

@@ -834,16 +837,35 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]:
return EPShardingInfo


class ShardingSource(Enum):
"""Enum for sharding source."""

HEURISTIC = "heuristic"
FACTORY = "factory"
CUSTOM = "custom"


class ShardingDim(Enum):
"""Enum for sharding dimension."""

TP = "tp"
EP = "ep"
BMM = "bmm"


class ShardingConfig(BaseModel):
"""Configuration for sharding the model."""

factory_source: ShardingConfigSource = Field(default=ShardingConfigSource.UNKNOWN)
rank: int = Field(default=0)
world_size: int = Field(default=1)
predefined_config: Optional[Dict[str, Any]] = None
custom_sharding_config: Optional[Dict[str, Any]] = None
simple_shard_only: bool = Field(default=False)
use_sharding_from_factory: bool = False
support_partial_config: bool = False
sharding_source: List[ShardingSource] = Field(
default_factory=lambda: [ShardingSource.HEURISTIC]
)
sharding_dims: List[str] = Field(default_factory=list)
tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
@@ -859,6 +881,41 @@ def _validate_and_normalize(self):
self.validate_config()
return self

def read_custom_sharding_config(self, config_path: str) -> bool:
Review comment (Member): I would like to avoid adding a separate yaml object just for sharding. We already have a general-purpose config reader; otherwise it gets too complicated. There is no need to add a separate yaml reader.

Review comment (Member): see this comment as well for more details: https://github.com/NVIDIA/TensorRT-LLM/pull/8153/files#r2437473448
"""Read the custom sharding config from the given path.

Supports both JSON and YAML file formats. The format is auto-detected
based on the file extension (.json, .yaml, .yml).
"""
path = Path(config_path)

if not path.exists():
ad_logger.warning(f"Sharding config file not found: {config_path}")
return False

try:
with open(config_path, "r") as f:
if path.suffix.lower() in [".yaml", ".yml"]:
self.custom_sharding_config = yaml.safe_load(f)
elif path.suffix.lower() == ".json":
self.custom_sharding_config = json.load(f)
else:
ad_logger.warning(f"Unsupported sharding config file format: {path.suffix}")
except Exception as e:
ad_logger.warning(f"Failed to read sharding config file: {e}")
return False
return True

def append_TP(self, tp_transform: TPShardingInfo) -> bool:
"""Append a TP transform only if that node was
not sharded before. Do not overwrite existing transforms.
"""
for existing_transform in self.tp_transforms:
if existing_transform.target_node == tp_transform.target_node:
Review comment (Member): is this sufficient to avoid conflicting/duplicate configurations?

return False
self.tp_transforms.append(tp_transform)
return True

def validate_config(self) -> bool:
if self.factory_source != ShardingConfigSource.HUGGINGFACE:
ad_logger.warning(
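On the duplicate-configuration question above: a standalone sketch of the append-if-absent guard, extended to also flag a conflict (same target node, different parameters) instead of silently skipping it. TPShardingInfo here is a hypothetical stand-in dataclass, not the real class.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class TPShardingInfo:  # hypothetical stand-in with made-up fields
    target_node: str
    split_dim: int


def append_tp(transforms: List[TPShardingInfo], new: TPShardingInfo) -> bool:
    """Append only if the target node is not sharded yet; never overwrite."""
    for existing in transforms:
        if existing.target_node == new.target_node:
            if existing != new:
                # Same node but different parameters: a conflicting request,
                # not just a duplicate, so worth logging rather than ignoring.
                print(f"conflicting sharding for {new.target_node}; keeping the first one")
            return False
    transforms.append(new)
    return True


tp: List[TPShardingInfo] = []
assert append_tp(tp, TPShardingInfo("linear1", 0)) is True
assert append_tp(tp, TPShardingInfo("linear1", 1)) is False  # conflict detected, not appended
```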
@@ -274,4 +274,6 @@ def run_sharding_pattern_detection_test(
print("detected_set", detected_set)
print("expected_set", expected_set)

assert detected_set == expected_set, "Expected sharding pattern does not match detected pattern"
assert detected_set == expected_set, (
f"Expected sharding pattern does not match detected pattern: {detected_set} != {expected_set}"
)
@@ -66,8 +66,9 @@ def _run_job(
{
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": False,
"sharding_source": ["heuristic"],
"sharding_dims": ["bmm"],
"support_partial_config": False,
},
"sharding_transform_executor": {
"stage": "sharding",
@@ -128,7 +129,8 @@ def _run_pattern_detection_job(
{
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": False,
"sharding_source": ["heuristic"],
"support_partial_config": False,
},
},
)
@@ -50,8 +50,9 @@ def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int:
{
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": False,
"sharding_source": ["heuristic"],
"sharding_dims": ["ep"],
"support_partial_config": False,
},
"sharding_transform_executor": {
"stage": "sharding",
@@ -118,7 +119,8 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) ->
{
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": False,
"sharding_source": ["heuristic"],
"support_partial_config": False,
},
},
)
@@ -1,12 +1,15 @@
"""Tests for basic graph sharding."""

# add to the path directory 4 directories up
import os
from functools import partial
from typing import Type

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from _dist_test_utils import get_device_counts
from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm
from _model_test_utils import FakeFP8Linear
@@ -193,12 +196,22 @@ def verify_local_weight_sizes(gm) -> bool:
op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)

gm = torch_export_to_gm(model, args=(x,), clone=True)
sharding_source = ["custom"] if from_config else ["heuristic"]

if sharding_source == ["custom"]:
# If the file does not exist, write predefined_config to tp_sharding.yaml file
if not os.path.exists("tp_sharding.yaml"):
with open("tp_sharding.yaml", "w") as f:
yaml.dump(predefined_config, f, sort_keys=False)
Review comment (Collaborator) on lines +204 to +205: Perhaps use Python's tempfile to avoid contaminating the current working directory.

Review comment (Collaborator, Author): The problem with tempfile is that this file has to be visible from a different thread directly from disk, so I cannot use a context like:

with tempfile.NamedTemporaryFile(mode='w+t', delete=True) as tmpfile:
    # Write to the file
    yaml.dump(predefined_config, tmpfile, sort_keys=False)
    ...

tempfile adds a unique id to either the temporary file or the temporary directory, but I need a fixed absolute path in the custom_sharding_config parameter to read it from.

Unless you know a good workaround?
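
One way to satisfy both constraints (nothing written into the current working directory, but a fixed absolute path that other processes can read from disk) could be tempfile.mkdtemp with a fixed file name inside the generated directory. A minimal sketch with placeholder config contents:

```python
import shutil
import tempfile
from pathlib import Path

import yaml

predefined_config = {"example": "placeholder"}  # stand-in for the real sharding config

# mkdtemp() creates a uniquely named directory, but the file name inside it is
# ours to choose, so the absolute path is known up front and lives on disk.
tmp_dir = Path(tempfile.mkdtemp(prefix="ad_sharding_"))
config_path = tmp_dir / "tp_sharding.yaml"
with open(config_path, "w") as f:
    yaml.dump(predefined_config, f, sort_keys=False)

# str(config_path) would be what gets passed as custom_sharding_config; in the
# real test the cleanup below would only run after the job has read the file.
shutil.rmtree(tmp_dir)
```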

Review comment (Member): See my other comment here: https://github.com/NVIDIA/TensorRT-LLM/pull/8153/files#r2437473448

If you do, you should be able to just provide the custom config as a dictionary without needing to create/read a tmp file.
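
A hypothetical sketch of what that could look like from the test's side, with the custom sharding entries passed inline as a dictionary so no file has to be written; whether custom_sharding_config accepts a dict is an assumption here, not something this diff shows:

```python
# Hypothetical: detect_sharding options with the custom config given inline
# as a dict instead of a path such as "tp_sharding.yaml".
detect_sharding_options = {
    "stage": "sharding",
    "sharding_source": ["custom"],
    "custom_sharding_config": {
        # per-layer sharding entries would go here
    },
    "support_partial_config": False,
    "sharding_dims": ["tp"],
}
```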

gm_transformed = InferenceOptimizer(
None,
{
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": from_config,
"sharding_source": sharding_source,
"custom_sharding_config": "tp_sharding.yaml",
"support_partial_config": False,
"sharding_dims": ["tp"],
},
"sharding_transform_executor": {
"stage": "sharding",
@@ -338,23 +351,33 @@ def _run_pattern_detection_job(
)
)

sharding_source = ["custom"] if from_config else ["heuristic"]

if sharding_source == ["custom"]:
# If the file does not exist, write predefined_config to tp_sharding.yaml file
if not os.path.exists("tp_sharding.yaml"):
with open("tp_sharding.yaml", "w") as f:
yaml.dump(predefined_config, f, sort_keys=False)

# get detected transformations
optimizer = InferenceOptimizer(
None,
{
"detect_sharding": {
"stage": "sharding",
"use_sharding_from_factory": from_config,
"sharding_source": sharding_source,
"custom_sharding_config": "tp_sharding.yaml",
"support_partial_config": False,
"sharding_dims": ["tp"],
},
},
)
optimizer.shared_config.local_rank = rank
optimizer.shared_config.world_size = world_size
optimizer.shared_config.sharding_config.predefined_config = predefined_config
_ = optimizer(None, gm)
detected_transformations = optimizer.shared_config.sharding_config.tp_transforms

print(f"detected_transformations: {detected_transformations}")
print(f"expected_transformations: {expected_transformations}")
# Run pattern detection test
run_sharding_pattern_detection_test(detected_transformations, expected_transformations)

@@ -409,7 +432,3 @@ def test_sharding_pattern_detection(
No need to run distributed job, can be run on single process.
"""
_run_pattern_detection_job(model_cls, bias, 0, world_size, from_config)


if __name__ == "__main__":
_run_pattern_detection_job(nn.Linear, False, 0, 8, False)