[TRTLLM-6342][feat] Support custom sharding config source #8153
Changes from all commits: 3433f1a, c105172, 8ae8869, 1af20e0, 3d1eb43, 0972ed6, 03f4190, 0cf4b2a, c46d742, 8765bbe, 99bca89
@@ -155,6 +155,16 @@ def from_last_info(cls, info: "TransformInfo") -> "TransformInfo":

```python
            has_valid_shapes=info.has_valid_shapes,
        )

    # overload += operator to concatenate TransformInfo objects
    def __iadd__(self, other: "TransformInfo") -> "TransformInfo":
        # since TransformInfo is frozen, instead, we return a new TransformInfo object
        return TransformInfo(
            skipped=self.skipped & other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean & other.is_clean,
            has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes,
        )
```
Review comment on lines +159 to +166:

> Please use the existing […]
>
> However, this assertion would fail since you actually create a new object.
```python
    def __or__(self, other: "TransformInfo") -> "TransformInfo":
        """Merge two TransformInfo objects."""
        return TransformInfo(
```
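For context on the review discussion above, here is a minimal standalone sketch (not the PR's actual class; only the field names and the body of `__iadd__` are taken from the diff) showing why `+=` on a frozen pydantic model has to return a new instance: augmented assignment rebinds the name to whatever `__iadd__` returns, so an identity check against the original object fails.

```python
from pydantic import BaseModel, ConfigDict


class TransformInfoSketch(BaseModel):
    """Simplified stand-in for TransformInfo; fields mirror the diff above."""

    model_config = ConfigDict(frozen=True)

    skipped: bool = False
    num_matches: int = 0
    is_clean: bool = True
    has_valid_shapes: bool = True

    def __iadd__(self, other: "TransformInfoSketch") -> "TransformInfoSketch":
        # frozen model: cannot mutate self in place, so build and return a new instance
        return TransformInfoSketch(
            skipped=self.skipped & other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean & other.is_clean,
            has_valid_shapes=self.has_valid_shapes & other.has_valid_shapes,
        )


a = TransformInfoSketch(num_matches=2)
b = TransformInfoSketch(num_matches=3, is_clean=False)
original = a
a += b  # rebinds `a` to the object returned by __iadd__
assert a is not original
assert a.num_matches == 5 and a.is_clean is False
```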
(Large diffs are not rendered by default; one changed file is omitted here.)
@@ -1,14 +1,17 @@

```python
"""Sharding config definitions for the inference optimizer."""

import json
import math
import operator
from abc import ABC, abstractmethod
from enum import Enum, IntEnum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence

import torch
import torch.nn as nn
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from torch.fx import GraphModule, Node
```
@@ -834,16 +837,35 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]:

```python
    return EPShardingInfo


class ShardingSource(Enum):
    """Enum for sharding source."""

    HEURISTIC = "heuristic"
    FACTORY = "factory"
    CUSTOM = "custom"


class ShardingDim(Enum):
    """Enum for sharding dimension."""

    TP = "tp"
    EP = "ep"
    BMM = "bmm"


class ShardingConfig(BaseModel):
    """Configuration for sharding the model."""

    factory_source: ShardingConfigSource = Field(default=ShardingConfigSource.UNKNOWN)
    rank: int = Field(default=0)
    world_size: int = Field(default=1)
    predefined_config: Optional[Dict[str, Any]] = None
    custom_sharding_config: Optional[Dict[str, Any]] = None
    simple_shard_only: bool = Field(default=False)
    use_sharding_from_factory: bool = False
    support_partial_config: bool = False
    sharding_source: List[ShardingSource] = Field(
        default_factory=lambda: [ShardingSource.HEURISTIC]
    )
    sharding_dims: List[str] = Field(default_factory=list)
    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
    bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
```
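To make the new knobs concrete, a construction sketch is shown below. This is an illustration only: the surrounding optimizer wiring is omitted, and the model validator shown in the next hunk appears to run `validate_config()` on construction, which may emit warnings.

```python
# hypothetical usage: prefer a user-provided config, fall back to heuristics
config = ShardingConfig(
    rank=0,
    world_size=8,
    sharding_source=[ShardingSource.CUSTOM, ShardingSource.HEURISTIC],
    sharding_dims=[ShardingDim.TP.value, ShardingDim.EP.value],  # sharding_dims is List[str]
)
print(config.sharding_source)         # [ShardingSource.CUSTOM, ShardingSource.HEURISTIC]
print(config.custom_sharding_config)  # None until read_custom_sharding_config() is called (see below)
```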
@@ -859,6 +881,41 @@ def _validate_and_normalize(self):

```python
        self.validate_config()
        return self
```
Review comments on the new `read_custom_sharding_config` method (below):

> I would like to avoid adding a separate YAML object just for sharding. We already have a general-purpose config reader; otherwise it gets too complicated. There is no need to add a separate YAML reader.

> See this comment as well for more details: https://github.com/NVIDIA/TensorRT-LLM/pull/8153/files#r2437473448

```python
    def read_custom_sharding_config(self, config_path: str) -> bool:
        """Read the custom sharding config from the given path.

        Supports both JSON and YAML file formats. The format is auto-detected
        based on the file extension (.json, .yaml, .yml).
        """
        path = Path(config_path)

        if not path.exists():
            ad_logger.warning(f"Sharding config file not found: {config_path}")
            return False

        try:
            with open(config_path, "r") as f:
                if path.suffix.lower() in [".yaml", ".yml"]:
                    self.custom_sharding_config = yaml.safe_load(f)
                elif path.suffix.lower() == ".json":
                    self.custom_sharding_config = json.load(f)
                else:
                    ad_logger.warning(f"Unsupported sharding config file format: {path.suffix}")
        except Exception as e:
            ad_logger.warning(f"Failed to read sharding config file: {e}")
            return False
        return True
```
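A quick sketch of how the reader above behaves from a caller's point of view (the file name and the fallback handling here are hypothetical; only the method and field names come from the diff):

```python
cfg = ShardingConfig(world_size=4)

# on success the parsed YAML/JSON dict is stored verbatim in custom_sharding_config
if cfg.read_custom_sharding_config("my_sharding.yaml"):
    print(cfg.custom_sharding_config)
else:
    # a missing file or a parse error only logs a warning and returns False,
    # so a caller could fall back to the heuristic sharding source here
    print("falling back to heuristic sharding")
```

Note that an unrecognized extension currently logs a warning but still returns True with `custom_sharding_config` left unset, which callers may want to guard against.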
```python
    def append_TP(self, tp_transform: TPShardingInfo) -> bool:
        """Append a TP transform only if that node was
        not sharded before. Do not overwrite existing transforms.
        """
        for existing_transform in self.tp_transforms:
            if existing_transform.target_node == tp_transform.target_node:
                return False
        self.tp_transforms.append(tp_transform)
        return True
```

Review comment on the duplicate check above:

> Is this sufficient to avoid conflicting/duplicate configurations?
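The reviewer's question is essentially about the first-wins semantics of this check: a second transform for the same `target_node` is dropped even if its sharding parameters differ, which is presumably the concern. A self-contained sketch with a stand-in class (the real `TPShardingInfo` fields are not shown in this diff, so only `target_node` is modeled):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class _FakeTPShardingInfo:
    """Stand-in for TPShardingInfo; only the attribute used by append_TP."""

    target_node: str


@dataclass
class _FakeShardingConfig:
    tp_transforms: List[_FakeTPShardingInfo] = field(default_factory=list)

    def append_TP(self, tp_transform: _FakeTPShardingInfo) -> bool:
        # first transform registered for a node wins; later ones are ignored
        for existing_transform in self.tp_transforms:
            if existing_transform.target_node == tp_transform.target_node:
                return False
        self.tp_transforms.append(tp_transform)
        return True


cfg = _FakeShardingConfig()
assert cfg.append_TP(_FakeTPShardingInfo("linear_1")) is True
assert cfg.append_TP(_FakeTPShardingInfo("linear_1")) is False  # duplicate node rejected
assert len(cfg.tp_transforms) == 1
```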
```python
    def validate_config(self) -> bool:
        if self.factory_source != ShardingConfigSource.HUGGINGFACE:
            ad_logger.warning(
```
@@ -1,12 +1,15 @@

```python
"""Tests for basic graph sharding."""

# add to the path directory 4 directories up
import os
from functools import partial
from typing import Type

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from _dist_test_utils import get_device_counts
from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm
from _model_test_utils import FakeFP8Linear
```
@@ -193,12 +196,22 @@ def verify_local_weight_sizes(gm) -> bool:

```python
    op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)

    gm = torch_export_to_gm(model, args=(x,), clone=True)
    sharding_source = ["custom"] if from_config else ["heuristic"]

    if sharding_source == ["custom"]:
        # If the file does not exist, write predefined_config to tp_sharding.yaml file
        if not os.path.exists("tp_sharding.yaml"):
            with open("tp_sharding.yaml", "w") as f:
                yaml.dump(predefined_config, f, sort_keys=False)
```
Review comments on lines +204 to +205:

> Perhaps use Python's tempfile, to avoid contaminating the current working dir.

> The problem with tempfile is that this file has to be visible from a different thread directly from disk anyway, so I cannot use a context like: […] Unless you know a good workaround for it?

> See my other comment here: https://github.com/NVIDIA/TensorRT-LLM/pull/8153/files#r2437473448 If you do, you should be able to just provide the custom config as a dictionary without needing to create/read a tmp file.
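One way to address the tempfile concern while still producing a real on-disk path that other ranks can open is `tempfile.mkstemp` (or `NamedTemporaryFile(delete=False)`): the file persists until it is explicitly removed, unlike the default context-managed temporary file. This is only a sketch of that idea, not code from the PR; the payload and comments are hypothetical.

```python
import os
import tempfile

import yaml

payload = {"example_key": "example_value"}  # hypothetical; the test would dump predefined_config

# mkstemp returns an open file descriptor plus a concrete path on disk
fd, cfg_path = tempfile.mkstemp(suffix=".yaml")
try:
    with os.fdopen(fd, "w") as f:
        yaml.dump(payload, f, sort_keys=False)
    # ... pass cfg_path to "custom_sharding_config" instead of a hard-coded
    # "tp_sharding.yaml" in the current working directory ...
finally:
    os.remove(cfg_path)
```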
```python
    gm_transformed = InferenceOptimizer(
        None,
        {
            "detect_sharding": {
                "stage": "sharding",
                "use_sharding_from_factory": from_config,
                "sharding_source": sharding_source,
                "custom_sharding_config": "tp_sharding.yaml",
                "support_partial_config": False,
                "sharding_dims": ["tp"],
            },
            "sharding_transform_executor": {
                "stage": "sharding",
```
@@ -338,23 +351,33 @@ def _run_pattern_detection_job(

```python
        )
    )

    sharding_source = ["custom"] if from_config else ["heuristic"]

    if sharding_source == ["custom"]:
        # If the file does not exist, write predefined_config to tp_sharding.yaml file
        if not os.path.exists("tp_sharding.yaml"):
            with open("tp_sharding.yaml", "w") as f:
                yaml.dump(predefined_config, f, sort_keys=False)
```
```python
    # get detected transformations
    optimizer = InferenceOptimizer(
        None,
        {
            "detect_sharding": {
                "stage": "sharding",
                "use_sharding_from_factory": from_config,
                "sharding_source": sharding_source,
                "custom_sharding_config": "tp_sharding.yaml",
                "support_partial_config": False,
                "sharding_dims": ["tp"],
            },
        },
    )
    optimizer.shared_config.local_rank = rank
    optimizer.shared_config.world_size = world_size
    optimizer.shared_config.sharding_config.predefined_config = predefined_config
    _ = optimizer(None, gm)
    detected_transformations = optimizer.shared_config.sharding_config.tp_transforms
```
```python
    print(f"detected_transformations: {detected_transformations}")
    print(f"expected_transformations: {expected_transformations}")
    # Run pattern detection test
    run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
```
@@ -409,7 +432,3 @@ def test_sharding_pattern_detection(

```diff
     No need to run distributed job, can be run on single process.
     """
     _run_pattern_detection_job(model_cls, bias, 0, world_size, from_config)
-
-
-if __name__ == "__main__":
-    _run_pattern_detection_job(nn.Linear, False, 0, 8, False)
```
Review comment on the removed `__main__` block:

> Looks like this is just a leftover from testing and should be reverted?