2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -77,7 +77,7 @@ transforms:
     simple_shard_only: false
     use_sharding_from_factory: false
     support_partial_config: false
-    sharding_dims: ['tp', 'ep', 'bmm']
+    sharding_dims: ['ep', 'bmm', 'ssm', 'tp']
     requires_shape_prop: true
     # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
16 changes: 15 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
@@ -9,6 +9,7 @@
 from einops import rearrange
 from transformers import AutoModelForCausalLM

+from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
 from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward


@@ -79,7 +80,7 @@ def _nemotron_h_block_forward(
     elif self.block_type == "attention":
         hidden_states = self.mixer(hidden_states, cache_position=cache_position)
         hidden_states = hidden_states[0]
-    elif self.block_type == "mlp":
+    elif self.block_type in ["mlp", "moe"]:
         hidden_states = self.mixer(hidden_states)
     else:
         raise ValueError(f"Invalid block_type: {self.block_type}")
@@ -112,6 +113,19 @@ def get_model_from_config_patched(config, **kwargs):
     return model


+def _set_sharding_config_patched(self, *args, **kwargs):
+    self._sharding_config["head_dim"] = 128
+    self._sharding_config["tp_plan"] = {
+        "in_proj": "mamba",
+        "out_proj": "rowwise",
+        "up_proj": "colwise",
+        "down_proj": "rowwise",
+        "*": "gather",
+    }
+
+
+AutoModelForCausalLMFactory._set_sharding_config = _set_sharding_config_patched
Review comment from a maintainer (Member) on the `_set_sharding_config` patch above: "will require clean-up"

# TODO: figure out how this can be incorporated into the export patch system
AutoModelForCausalLM.from_config = get_model_from_config_patched
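
Note on the patched `_set_sharding_config`: the `tp_plan` maps parameter-name patterns to sharding strategies, with `"*"` acting as a catch-all. Below is a minimal sketch of how such a mapping could be resolved for a given weight name; the `resolve_tp_strategy` helper is hypothetical and not AutoDeploy's actual matching logic.

```python
from typing import Dict


def resolve_tp_strategy(param_name: str, tp_plan: Dict[str, str]) -> str:
    """Hypothetical helper (not AutoDeploy's API): return the strategy of the first
    non-wildcard pattern contained in the parameter name, else the '*' fallback."""
    for pattern, strategy in tp_plan.items():
        if pattern != "*" and pattern in param_name:
            return strategy
    return tp_plan.get("*", "gather")


tp_plan = {
    "in_proj": "mamba",      # presumably the Mamba/SSM input projection
    "out_proj": "rowwise",
    "up_proj": "colwise",
    "down_proj": "rowwise",
    "*": "gather",           # catch-all for everything else
}

assert resolve_tp_strategy("backbone.layers.3.mixer.in_proj.weight", tp_plan) == "mamba"
assert resolve_tp_strategy("lm_head.weight", tp_plan) == "gather"
```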

9 changes: 9 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/transform/interface.py
@@ -173,6 +173,15 @@ def __and__(self, other: "TransformInfo") -> "TransformInfo":
             has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
         )

+    # implement + addition operator for TransformInfo
+    def __add__(self, other: "TransformInfo") -> "TransformInfo":
+        return TransformInfo(
+            skipped=self.skipped and other.skipped,
+            num_matches=self.num_matches + other.num_matches,
+            is_clean=self.is_clean and other.is_clean,
+            has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
+        )


TransformHistory = Dict[str, TransformInfo]
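
For reference, a small usage sketch of the new `+` operator: it sums `num_matches` across passes while AND-ing the boolean flags, complementing the existing `&`. The constructor call below uses only the field names visible in the diff and is meant as an illustration.

```python
from tensorrt_llm._torch.auto_deploy.transform.interface import TransformInfo

# Combine the results of two transform passes: matches are summed,
# boolean flags are AND-ed.
info_a = TransformInfo(skipped=False, num_matches=3, is_clean=True, has_valid_shapes=True)
info_b = TransformInfo(skipped=False, num_matches=2, is_clean=False, has_valid_shapes=True)

combined = info_a + info_b
assert combined.num_matches == 5
assert combined.is_clean is False
assert combined.has_valid_shapes is True
```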

(additional file; name not shown in this view)
@@ -294,7 +294,7 @@ def _find_final_hidden_state_node(
     if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
         return None
     index_node = mul_node.args[1]
-    index_add_node = bfs(
+    index_add_node, _ = bfs(
         index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
     )
     if not index_add_node:
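
The call-site updates in this file follow from `bfs` now returning a tuple rather than the matched node alone; only the first element is used here. The toy `bfs` below sketches that calling pattern over a plain adjacency dict and returns the match plus its depth; the real `node_utils.bfs` walks `torch.fx` nodes and its second return value is not shown in this diff.

```python
from collections import deque
from typing import Callable, Dict, Hashable, List, Tuple


def bfs(
    start: Hashable,
    target: Callable[[Hashable], bool],
    neighbors: Dict[Hashable, List[Hashable]],
) -> Tuple[Hashable, int]:
    """Toy BFS over an adjacency dict: return (first match, its depth) or raise."""
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        item, depth = queue.popleft()
        if target(item):
            return item, depth
        for nxt in neighbors.get(item, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, depth + 1))
    raise RuntimeError("no matching node found")


graph = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
node, _ = bfs("a", lambda n: n == "d", graph)  # unpack; the extra value is discarded
assert node == "d"
```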
@@ -360,7 +360,7 @@ def target(n: torch.fx.Node) -> bool:
         return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0

     try:
-        node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
+        node_to_remove, _ = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
         graph.erase_node(node_to_remove)
         return True
     except RuntimeError:
@@ -430,7 +430,7 @@ def _apply(
         common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
         if not common_ancessor2:
             continue
-        selected_experts = bfs(
+        selected_experts, _ = bfs(
            common_ancessor2,
            lambda node: is_op(node, torch.ops.aten.one_hot),
            attr_next="all_input_nodes",
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
@@ -13,7 +13,7 @@
 from ...shim.interface import CachedSequenceInterface
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
 from ...utils.logger import ad_logger
-from ...utils.node_utils import extract_param_names_from_lin_node, is_linear_op, is_op
+from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
         y2 = y[:, out1:out1+out2]
     """
     # some info we need
-    keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes]
+    keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
     params_unfused = [gm.get_parameter(k) for k in keys_unfused]
     sizes_unfused = [p.size(0) for p in params_unfused]
     key_fused = f"fused_weight_{idx}"
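
Background for `_insert_fused_gemm`: linear layers that share an input are fused by concatenating their weights along the output dimension, running one GEMM, and splitting the result back per the original output sizes (the `y2 = y[:, out1:out1+out2]` slice in the docstring). A small numeric sketch of that idea in plain PyTorch, independent of the graph-rewrite machinery here:

```python
import torch

x = torch.randn(4, 16)                 # shared input to two linear layers
w1 = torch.randn(8, 16)                # weight of linear 1 (out1 = 8)
w2 = torch.randn(24, 16)               # weight of linear 2 (out2 = 24)

# Unfused: two separate GEMMs.
y1_ref, y2_ref = x @ w1.t(), x @ w2.t()

# Fused: concatenate along the output dim (dim 0), run one GEMM, then split,
# mirroring the slicing shown in the docstring above.
w_fused = torch.cat([w1, w2], dim=0)   # analogous to the fused_weight_{idx} parameter
y = x @ w_fused.t()
y1, y2 = y[:, :8], y[:, 8 : 8 + 24]

torch.testing.assert_close(y1, y1_ref)
torch.testing.assert_close(y2, y2_ref)
```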
@@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple
     def _insert_fused_quant_gemm(
         self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
     ):
-        keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes]
+        keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
         params_unfused = [gm.get_parameter(k) for k in keys_unfused]
         sizes_unfused = [p.size(0) for p in params_unfused]
         key_fused = f"fused_weight_{idx}"
(additional file; name not shown in this view)
@@ -14,7 +14,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
-    extract_param_names_from_lin_node,
+    extract_param_names_from_node,
     get_quantization_params_from_linear_node,
     is_bmm_op,
     is_linear_op,
@@ -136,7 +136,7 @@ def _insert_quantized_linear(

         The state_dict is also updated to contain the sharded weights.
         """
-        param_name, _ = extract_param_names_from_lin_node(node)
+        param_name, _ = extract_param_names_from_node(node)
         original_weight = gm.get_parameter(param_name)
         new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
         modname, _, attrname = param_name.rpartition(".")
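
The `rpartition(".")` split is the usual way to swap a parameter on its owning submodule given a dotted parameter name. Below is a standalone sketch of that pattern; `fake_quantize` is a stand-in for `self.quantize_weight`, and the toy `nn.Sequential` stands in for the traced GraphModule.

```python
import torch
import torch.nn as nn


def fake_quantize(w: torch.Tensor) -> torch.Tensor:
    """Placeholder for the real quantize_weight: fake int8 round-trip."""
    scale = w.abs().max() / 127.0
    return (w / scale).round().clamp(-127, 127) * scale


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
param_name = "0.weight"  # dotted name, as returned by named_parameters()

original = model.get_parameter(param_name)
new_param = nn.Parameter(fake_quantize(original.detach()), requires_grad=False)

# Split "0.weight" into module path ("0") and attribute name ("weight"),
# then replace the parameter on the owning submodule.
modname, _, attrname = param_name.rpartition(".")
setattr(model.get_submodule(modname), attrname, new_param)

assert model.get_parameter(param_name) is new_param
assert not model.get_parameter(param_name).requires_grad
```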