From 1b312b75020cb8d8a0b3ec160062d8e96604dfa2 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Fri, 17 Oct 2025 02:47:22 -0700 Subject: [PATCH 01/12] WiP parallelizing mamba2 layer Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- modelling_nemotron_h.py | 1899 +++++++++++++++++ simple_nemotron.py | 731 +++++++ .../_torch/auto_deploy/config/default.yaml | 4 +- .../auto_deploy/models/patches/nemotron_h.py | 49 +- .../_torch/auto_deploy/utils/node_utils.py | 10 +- .../auto_deploy/utils/sharding_utils.py | 2 + 6 files changed, 2684 insertions(+), 11 deletions(-) create mode 100644 modelling_nemotron_h.py create mode 100644 simple_nemotron.py diff --git a/modelling_nemotron_h.py b/modelling_nemotron_h.py new file mode 100644 index 00000000000..2c51f60b93f --- /dev/null +++ b/modelling_nemotron_h.py @@ -0,0 +1,1899 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch NemotronH model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.cache_utils import \ + DynamicCache # we need __iter__ and __len__ of pkv +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging) +from transformers.utils.import_utils import ( + is_causal_conv1d_available, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, is_mamba_2_ssm_available) + +from .configuration_nemotron_h import NemotronHConfig + +logger = logging.get_logger(__name__) + +# Copied from transformers.models.mamba.modeling_mamba2.modeling_mamba2.py with MAMBA2->NEMOTRONH,Mamba2->NemotronH +# For Mamba2 components Mamba2->NemotronHMamba2 +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import \ + selective_state_update + from mamba_ssm.ops.triton.ssd_combined import ( + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined) +else: + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None + +try: + #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn +except ImportError: + raise ImportError( + "mamba-ssm is required by the Mamba model but cannot be imported") + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if 
is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import \ + _flash_attention_forward + +is_fast_path_available = all(( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, +)) + +_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K" +_CONFIG_FOR_DOC = "NemotronHConfig" + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, + 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, + 0) + + return torch.nn.functional.pad(input_tensor, + pad_shape, + mode="constant", + value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, + input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, + input_tensor.shape[2], + input_tensor.shape[3]) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), + chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, + chunk_size, + device=input_tensor.device, + dtype=torch.bool), + diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, + chunk_size, + device=input_tensor.device, + dtype=torch.bool), + diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[ + 1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
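+
+    A minimal construction sketch (illustrative only; assumes an already-loaded `config`, an `input_ids`
+    batch, and a `NemotronHModel` instance `model`):
+
+        cache = HybridMambaAttentionDynamicCache(config, batch_size=input_ids.shape[0],
+                                                 dtype=torch.bfloat16, device=input_ids.device)
+        outputs = model(input_ids, cache_params=cache, use_cache=True)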
+ """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_override_pattern + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_num_heads * config.mamba_head_dim + ssm_state_size = config.ssm_state_size + conv_kernel_size = config.conv_kernel + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "M": + # Mamba layer + self.conv_states += [ + torch.zeros(batch_size, + intermediate_size, + conv_kernel_size, + device=device, + dtype=dtype) + ] + self.ssm_states += [ + torch.zeros(batch_size, + intermediate_size, + ssm_state_size, + device=device, + dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [ + torch.tensor([[]] * batch_size, device=device) + ] + self.ssm_states += [ + torch.tensor([[]] * batch_size, device=device) + ] + self.transformer_layers.append(i) + + self.key_cache = [ + torch.tensor([[]] * batch_size, device=device) + for _ in range(config.num_hidden_layers) + ] + self.value_cache = [ + torch.tensor([[]] * batch_size, device=device) + for _ in range(config.num_hidden_layers) + ] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[ + layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[ + layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[ + layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[ + 0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache( + self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError( + "HybridMambaAttentionDynamicCache does not have a legacy cache equivalent." + ) + + @classmethod + def from_legacy_cache( + cls, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": + raise NotImplementedError( + "HybridMambaAttentionDynamicCache does not have a legacy cache equivalent." 
+ ) + + # Copied from modeling_mamba2.py + def update_conv_state(self, + layer_idx: int, + new_conv_state: torch.Tensor, + cache_init: bool = False) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to( + self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll( + shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to( + self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaRMSNormGated(torch.nn.Module): + + def __init__(self, hidden_size, group_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + + # jan28b version + def forward(self, hidden_states, gate=None): + return rmsnorm_fn( + x=hidden_states, + weight=self.weight, + bias=None, # No bias + z=gate, + eps=self.variance_epsilon, + group_size=self.group_size, + norm_before_gate=False) + + +class NemotronHMamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: NemotronHConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.ssm_state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.mamba_hidden_act + self.act = ACT2FN[config.mamba_hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + + self.n_groups = config.n_groups + self.head_dim = config.mamba_head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, + eps=self.layer_norm_epsilon, + group_size=self.intermediate_size // + self.n_groups) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, + self.hidden_size, + bias=config.use_bias) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d") + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, + attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - + 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[ + 0] > 0: + _, _, gate, hidden_states_B_C, dt = projected_states.squeeze( + 1).split([ + d_mlp, d_mlp, self.intermediate_size, self.conv_dim, + self.num_heads + ], + dim=-1) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, groups_time_state_size, + groups_time_state_size + ], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, + ...][:, :, + None].expand(-1, self.head_dim, + self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, + self.num_heads, + self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, + self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float() + ) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == ( + 0.0, float("inf")) else { + "dt_limit": self.time_step_limit + } + + # 2-4. 
Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [ + d_mlp, d_mlp, self.intermediate_size, self.conv_dim, + self.num_heads + ], + dim=-1) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose( + 1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - + hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, + new_conv_state=conv_states, + cache_init=True) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose( + 1, 2))[..., :seq_len].transpose(1, 2)) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + hidden_states_B_C = apply_mask_to_padding_states( + hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, groups_time_state_size, + groups_time_state_size + ], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, + new_ssm_state=ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states.device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
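+            # dt: [bsz, 1, num_heads] after keeping only the current token; it is transposed and
+            # broadcast to [bsz, num_heads, head_dim] below before the softplus and clamp.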
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
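+        # Multi-token path below: naive chunked SSD implementation over the full sequence.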
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, + cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[ + 1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * + attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, + attention_mask) + + +class NemotronHRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + # Weights are in float32 + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + +class NemotronHBlock(nn.Module): + + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = NemotronHRMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + # M: Mamba2, *: Attention, -: MLP + self.block_type = config.layers_block_type[layer_idx] + if self.block_type == "mamba": + self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx) + elif self.block_type == "attention": + self.mixer = NEMOTRONH_ATTENTION_CLASSES[ + config._attn_implementation](config, layer_idx=layer_idx) + elif self.block_type == "mlp": + self.mixer = NemotronHMLP(config, layer_idx=layer_idx) + elif self.block_type == "moe": + self.mixer = NemotronHMOE(config, layer_idx=layer_idx) + else: + raise ValueError( + f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}" + ) + + def forward( + self, + hidden_states, + cache_params: 
Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): + # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs + residual = hidden_states + hidden_states = self.norm( + hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + if self.block_type == "mamba": + hidden_states = self.mixer(hidden_states, + cache_params=cache_params, + cache_position=cache_position) + elif self.block_type == "attention": + hidden_states = self.mixer(hidden_states, + cache_position=cache_position) + hidden_states = hidden_states[0] + elif self.block_type in ["mlp", "moe"]: + hidden_states = self.mixer(hidden_states) + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + hidden_states = residual + hidden_states + return hidden_states + + +# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH +class NemotronHMLP(nn.Module): + + def __init__(self, + config, + intermediate_size=None, + layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class.") + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size or config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, + self.intermediate_size, + bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, + self.hidden_size, + bias=config.mlp_bias) + self.act_fn = ACT2FN[config.mlp_hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +class NemotronHMOE(nn.Module): + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.experts = nn.ModuleList([ + NemotronHMLP(config, + intermediate_size=config.moe_intermediate_size, + layer_idx=layer_idx) + for _ in range(config.n_routed_experts) + ]) + self.gate = NemotronHTopkRouter(config) + self.shared_experts = NemotronHMLP( + config=config, + intermediate_size=config.moe_shared_expert_intermediate_size, + layer_idx=layer_idx) + + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, + topk_weights: torch.Tensor): + r""" + CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). 
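+
+        Current behaviour: for each expert, the tokens routed to it are gathered, run through that
+        expert's MLP, scaled by their router weights, and scattered back into the output via `index_add_`.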
+ """ + final_hidden_states = torch.zeros_like(hidden_states, + dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, + num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, + weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_indices, + topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class NemotronHTopkRouter(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size), + dtype=torch.float32)) + self.register_buffer( + "e_score_correction_bias", + torch.zeros(self.n_routed_experts, dtype=torch.float32)) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view( + -1, + self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = (scores_for_choice.view( + -1, self.n_group, + self.n_routed_experts // self.n_group).topk(2, + dim=-1)[0].sum(dim=-1)) + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = (group_mask.unsqueeze(-1).expand( + -1, self.n_group, self.n_routed_experts // self.n_group).reshape( + -1, self.n_routed_experts)) + scores_for_choice = scores_for_choice.masked_fill( + ~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, + k=self.top_k, + dim=-1, + sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), + self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of 
torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +class NemotronHAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, + config: NemotronHConfig, + layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class.") + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = config.hidden_size // self.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias) + self.o_proj = nn.Linear(self.head_dim * self.num_heads, + self.hidden_size, + bias=config.attention_bias) + + def forward( + self, + hidden_states: torch.Tensor, + # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + 
key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + #attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, + self.num_heads * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba +#class JambaFlashAttention2(JambaAttention): +class NemotronHFlashAttention2(NemotronHAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}.") + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + #attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * + self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba +#class JambaSdpaAttention(JambaAttention): +class NemotronHSdpaAttention(NemotronHAttention): + """ + Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from NemotronHAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +NEMOTRONH_ATTENTION_CLASSES = { + "eager": NemotronHAttention, + "flash_attention_2": NemotronHFlashAttention2, + "sdpa": NemotronHSdpaAttention, +} + + +# Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel +class NemotronHPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = NemotronHConfig + base_model_prefix = "backbone" + _no_split_modules = ["NemotronHBlock"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, NemotronHMamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = torch.exp( + torch.rand(self.config.mamba_num_heads) * + (math.log(self.config.time_step_max) - + math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min)).clamp( + min=self.config.time_step_floor) + + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + # TODO: Check + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH +class NemotronHOutput(ModelOutput): + """ + Class for the NemotronH model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH +class NemotronHCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +NEMOTRONH_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +NEMOTRONH_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + cache_params (`HybridMambaAttentionDynamicCache`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the current input in the cache. This is used to ensure that the cache is correctly updated. + If `cache_params` is passed, `cache_position` should also be passed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) +""" + + +@add_start_docstrings( + "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.", + NEMOTRONH_START_DOCSTRING, +) +class NemotronHModel(NemotronHPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([ + NemotronHBlock(config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ]) + + self.gradient_checkpointing = False + self.norm_f = NemotronHRMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", + "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, NemotronHOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + # use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else ( + self.config.use_cache if not self.training else False) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds + is not None): # ^ is python for xor + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # From zamba_modeling.py + if use_cache and cache_params is None: + logger.warning_once( + "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. 
None was " + "provided, so no cache will be returned.") + + hidden_states = inputs_embeds + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], + device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, + cache_position) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + # Until HERE + + for layer_idx, mixer_block in enumerate(self.layers): + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + if mixer_block.block_type == "mamba": + layer_mask = mamba_mask + elif mixer_block.block_type == "attention": + layer_mask = causal_mask + elif mixer_block.block_type in ["mlp", "moe"]: + layer_mask = None + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, + cache_position, layer_mask) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + + # TODO: Store attentions + # if output_attentions: + # if layer_outputs[1] is not None: + # # append attentions only of attention layers. Mamba layers return `None` as the attention weights + # all_self_attns += (layer_outputs[1],) + + # TODO (Check): should it happen before the forward pass? + # if output_hidden_states: + # all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, cache_params, all_hidden_states] + if v is not None) + + return NemotronHOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, + device=device) > cache_position.reshape( + -1, 1) + causal_mask = causal_mask[None, + None, :, :].expand(input_tensor.shape[0], 1, + -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone( + ) # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq( + 0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[ + ..., :mask_length].masked_fill(padding_mask, min_dtype) + + if 
(self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda"): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None + and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +@add_start_docstrings( + """ + The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input + embeddings). + """, + NEMOTRONH_START_DOCSTRING, +) +class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = NemotronHModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, + config.vocab_size, + bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def set_decoder(self, decoder): + self.model = decoder + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py + # Overwritten -- uses `cache_params` as opposed to `past_key_values` + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if (inputs_embeds is not None # Exception 1 + or cache_position[-1] >= input_ids.shape[1] # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0]:] + elif input_ids.shape[1] != cache_position.shape[ + 0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = { + "input_ids": input_ids.contiguous() + } # `contiguous()` needed for compilation use cases + + model_inputs.update({ + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + }) + return model_inputs + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, NemotronHCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + nemotron_h_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = nemotron_h_outputs[0] + + # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2 + #logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = self.lm_head(hidden_states.to( + self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1)) + + if not return_dict: + output = (logits, ) + nemotron_h_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return NemotronHCausalLMOutput( + loss=loss, + logits=logits, + cache_params=nemotron_h_outputs.cache_params, + hidden_states=nemotron_h_outputs.hidden_states, + attentions=nemotron_h_outputs.attentions, + ) diff --git a/simple_nemotron.py b/simple_nemotron.py new file mode 100644 index 00000000000..8313922f3ae --- /dev/null +++ b/simple_nemotron.py @@ -0,0 +1,731 @@ +""" +Simplified NemotronHMamba2Mixer - Tensor Algebra Operations Only + +This file focuses on the tensor algebra operations in the Mamba2 forward pass, +with detailed annotations for parallelization across multiple GPUs. + +Notation: +- b: batch_size +- s: seq_len +- h_in: hidden_size (input) +- h: num_heads +- d: head_dim +- n: ssm_state_size +- g: n_groups +- i: intermediate_size (= h * d) +- c: chunk_size +- num_chunks: number of chunks (= ceil(s / c)) + +Key relationships: +- intermediate_size = num_heads * head_dim (i = h * d) +- conv_dim = intermediate_size + 2 * n_groups * ssm_state_size +""" + +from typing import Optional + +import torch +import torch.nn as nn + + +class NemotronHMamba2Mixer: + """ + Mamba2 SSM Mixer - Tensor Algebra Only + + This class contains only the algebraically significant operations, + annotated with parallelization strategies. 
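+
+ Worked example with the default sizes set in `__init__` below:
+ i = h * d = 64 * 64 = 4096, conv_dim = i + 2 * g * n = 4096 + 2 * 8 * 128 = 6144,
+ so the fused in_proj output width is i + conv_dim + h = 4096 + 6144 + 64 = 10304.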
+ """ + + def __init__(self): + # Model dimensions (example values) + self.hidden_size = 4096 # h_in + self.num_heads = 64 # h + self.head_dim = 64 # d + self.intermediate_size = 4096 # i = h * d + self.n_groups = 8 # g + self.ssm_state_size = 128 # n + self.chunk_size = 256 # c + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + + # analogy to transformers' attention: + # A - query [b, s, h, d] + # B - key [b, s, g, d] # n_groups function as num KV heads + # C - value [b, s, g, d] + # D - attention mask + # B and C will be broadcasted from g to h for SSM computation + + # Learnable parameters + conv_kernel = 4 + self.in_proj = nn.Linear( + self.hidden_size, + self.intermediate_size + self.conv_dim + self.num_heads) + self.conv1d = nn.Conv1d( + self.conv_dim, + self.conv_dim, + kernel_size=conv_kernel, + groups=self.conv_dim, + padding=conv_kernel - + 1 # This ensures output length >= input length + ) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size) + self.A_log = nn.Parameter(torch.randn(self.num_heads)) + self.dt_bias = nn.Parameter(torch.randn(self.num_heads)) + self.D = nn.Parameter(torch.randn(self.num_heads)) + + def segment_sum(self, input_tensor): + """ + Segment sum operation - computes cumulative sum within triangular mask. + + Input: [..., chunk_size] + Output: [..., chunk_size, chunk_size] + + PARALLELIZATION ANALYSIS: + - All batch dimensions (...): FULLY PARALLEL (embarrassingly parallel) + - chunk_size dimension: SEQUENTIAL (cumsum is inherently sequential) + - Can parallelize across chunks if processing multiple chunks + - Cross-GPU: Can distribute batch/head dimensions, but cumsum requires local computation + """ + chunk_size = input_tensor.size(-1) + + # Input: [..., c] -> [..., c, c] + # Complexity: O(c^2) per element in batch + # Parallel: All leading dims are independent + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), + chunk_size) + + # Cumsum along dim=-2 + # Input: [..., c, c], Output: [..., c, c] + # Complexity: O(c^2) per element + # Parallel: Leading dims (...) are independent, but cumsum is sequential in last-2 dim + # WARNING: cumsum is NOT parallelizable in the reduction dimension + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + return tensor_segsum + + def torch_forward_algebra_only( + self, + input_states: torch.Tensor, # [b, s, h_in] + cache_params: Optional = None, + debug: bool = False): + """ + Forward pass with TENSOR ALGEBRA operations only. + Focus: matrix multiplications, reductions, cumulative operations. + Excluded: element-wise ops (activations, exp, masking, etc.) 
+ """ + + batch_size, seq_len, _ = input_states.shape # b, s, h_in + + # ============================================================================= + # STEP 1: Input Projection (Linear Layer) + # ============================================================================= + # Operation: projected_states = input_states @ in_proj.weight^T + in_proj.bias + # Input: [b, s, h_in] + # Weight: [projection_size, h_in] where projection_size = i + conv_dim + h + # Output: [b, s, projection_size] + # Complexity: O(b * s * h_in * projection_size) ≈ O(b * s * h_in^2) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL - can split across GPUs with no communication + # - Sequence (s): FULLY PARALLEL - can split across GPUs with no communication + # - Hidden (h_in): PARALLEL with ALL_REDUCE - this is the reduction dimension + # * If split h_in across GPUs, need all_reduce to sum partial results + # * Row-parallel: split weight rows, no all_reduce needed + # * Column-parallel: split weight columns, need all_reduce after + # - Output (projection_size): PARALLEL - row-wise split requires no communication + # + # TENSOR PARALLEL STRATEGIES: + # 1. Batch parallel: Each GPU processes different batch elements + # 2. Sequence parallel: Each GPU processes different tokens (works for attention) + # 3. Tensor parallel (column): Split projection_size, all_reduce on h_in + # 4. Tensor parallel (row): Split h_in, each GPU computes partial projection + projected_states = self.in_proj(input_states) # [b, s, projection_size] + + # Split the projection into components + # gate: [b, s, i], hidden_states_B_C: [b, s, conv_dim], dt: [b, s, h] + # Note: d_mlp is computed but will be 0 in this configuration + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - + 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2 + + if debug: + print("\nProjection split:") + print(f" projected_states shape: {projected_states.shape}") + print(f" d_mlp: {d_mlp}") + print(f" Split sizes: [d_mlp={d_mlp}, d_mlp={d_mlp}, " + f"intermediate={self.intermediate_size}, " + f"conv_dim={self.conv_dim}, num_heads={self.num_heads}]") + print( + f" Total: {2*d_mlp + self.intermediate_size + self.conv_dim + self.num_heads}" + ) + + # Split into components (d_mlp will be 0, so first two splits are empty) + splits = [] + current_idx = 0 + for size in [ + d_mlp, d_mlp, self.intermediate_size, self.conv_dim, + self.num_heads + ]: + if size > 0: + splits.append(projected_states[..., + current_idx:current_idx + size]) + else: + splits.append( + projected_states[..., + current_idx:current_idx]) # Empty tensor + current_idx += size + + _, _, gate, hidden_states_B_C, dt = splits[0], splits[1], splits[ + 2], splits[3], splits[4] + + if debug: + print( + f" After split - gate: {gate.shape}, hidden_states_B_C: {hidden_states_B_C.shape}, dt: {dt.shape}" + ) + + # ============================================================================= + # STEP 2: Conv1D Operation + # ============================================================================= + # Conv1D is applied on sequence dimension + # Input: [b, conv_dim, s] (after transpose) + # Weight: [conv_dim, 1, kernel_size] + # Output: [b, conv_dim, s] + # Complexity: O(b * conv_dim * s * kernel_size) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Channel (conv_dim): FULLY PARALLEL (depthwise conv, groups=conv_dim) + # - Sequence (s): PARALLEL with communication + # * Conv requires kernel_size-1 halo elements from neighbors + # * Split sequence: need halo exchange between 
GPUs + # * First/last kernel_size-1 tokens need data from adjacent GPUs + # + # TENSOR PARALLEL STRATEGIES: + # 1. Batch parallel: Easiest, no communication + # 2. Channel parallel: Split conv_dim, no cross-channel communication (depthwise) + # 3. Sequence parallel: Need halo exchange (communication overhead) + hidden_states_B_C_transposed = hidden_states_B_C.transpose( + 1, 2) # [b, conv_dim, s] + conv_out = self.conv1d(hidden_states_B_C_transposed)[ + ..., :seq_len] # [b, conv_dim, s] + hidden_states_B_C = conv_out.transpose(1, 2) # [b, s, conv_dim] + + # Split conv output + # conv_dim = intermediate_size + 2 * n_groups * ssm_state_size + split_sizes = [ + self.intermediate_size, self.n_groups * self.ssm_state_size, + self.n_groups * self.ssm_state_size + ] + + # Verify split sizes match conv_dim + assert sum( + split_sizes + ) == self.conv_dim, f"Split sizes {split_sizes} don't sum to conv_dim {self.conv_dim}" + + hidden_states = hidden_states_B_C[..., :self.intermediate_size] + B = hidden_states_B_C[..., + self.intermediate_size:self.intermediate_size + + self.n_groups * self.ssm_state_size] + C = hidden_states_B_C[..., self.intermediate_size + + self.n_groups * self.ssm_state_size:] + + # hidden_states: [b, s, i], B: [b, s, g*n], C: [b, s, g*n] + if debug: + print( + f"After split - hidden_states: {hidden_states.shape}, B: {B.shape}, C: {C.shape}" + ) + + # ============================================================================= + # STEP 3: SSM State Space Computation (Main Computation) + # ============================================================================= + + # Reshape for SSM computation + # hidden_states: [b, s, i] -> [b, s, h, d] + # B: [b, s, g*n] -> [b, s, g, n] + # C: [b, s, g*n] -> [b, s, g, n] + # Complexity: O(1) - just view operations + # Parallel: All dimensions are independent + + if debug: + print( + f"Before reshape - hidden_states: {hidden_states.shape}, expected: [{batch_size}, {seq_len}, {self.intermediate_size}]" + ) + print( + f"Reshape target: [{batch_size}, {seq_len}, {self.num_heads}, {self.head_dim}]" + ) + print( + f"intermediate_size={self.intermediate_size}, num_heads={self.num_heads}, head_dim={self.head_dim}" + ) + print(f"num_heads * head_dim = {self.num_heads * self.head_dim}") + + # Verify dimensions are compatible + assert hidden_states.shape[-1] == self.num_heads * self.head_dim, \ + f"Cannot reshape {hidden_states.shape} to have {self.num_heads} heads of dim {self.head_dim}" + + hidden_states = hidden_states.reshape(batch_size, seq_len, + self.num_heads, self.head_dim) + B = B.reshape(batch_size, seq_len, self.n_groups, self.ssm_state_size) + C = C.reshape(batch_size, seq_len, self.n_groups, self.ssm_state_size) + + if debug: + print( + f"After reshape - hidden_states: {hidden_states.shape}, B: {B.shape}, C: {C.shape}" + ) + + # Repeat B and C to match num_heads (from n_groups) + # Input: [b, s, g, n] + # Output: [b, s, h, n] where h = g * repetition_factor + # Complexity: O(b * s * h * n) memory, O(1) compute (just indexing) + # Parallel: FULLY PARALLEL - simple replication + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) # [b, s, h, n] + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) # [b, s, h, n] + + # Compute pad size for chunking + pad_size = (self.chunk_size - + seq_len % self.chunk_size) % self.chunk_size + + # ============================================================================= + # STEP 3a: Chunk Reshaping + # ============================================================================= + # Reshape 
sequences into chunks + # Input: [b, s, h, d] + # Output: [b, num_chunks, c, h, d] where num_chunks = ceil(s/c) + # Complexity: O(1) - reshape only + # Parallel: All dimensions independent + # + # Note: This creates a new dimension (num_chunks) that can be parallelized! + def reshape_into_chunks(tensor, pad_size, chunk_size): + """Pad and reshape into chunks""" + # Pad: increases sequence length by pad_size + # Reshape: [b, s+pad, ...] -> [b, num_chunks, c, ...] + # Parallel: Independent across batch dimension + return tensor # Simplified - actual implementation in original code + + # After chunking (conceptual): + # hidden_states: [b, num_chunks, c, h, d] + # A: [b, h, num_chunks, c] (permuted for computation) + # B: [b, num_chunks, c, h, n] + # C: [b, num_chunks, c, h, n] + + # ============================================================================= + # STEP 3b: Cumulative Sum (Sequential Operation) + # ============================================================================= + # A_cumsum = torch.cumsum(A, dim=-1) + # Input: [b, h, num_chunks, c] + # Output: [b, h, num_chunks, c] + # Complexity: O(b * h * num_chunks * c) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Heads (h): FULLY PARALLEL + # - Chunks (num_chunks): FULLY PARALLEL - each chunk is independent! + # - Chunk_size (c): SEQUENTIAL - cumsum is inherently sequential + # + # CRITICAL: cumsum within each chunk is sequential, but different chunks + # can be computed in parallel! This is why chunking is valuable. + # + # TENSOR PARALLEL: Can split b, h, num_chunks across GPUs with NO communication + A_cumsum = torch.zeros(batch_size, self.num_heads, + (seq_len + pad_size) // self.chunk_size, + self.chunk_size) + + # ============================================================================= + # STEP 3c: Segment Sum (calls cumsum internally) + # ============================================================================= + # L = torch.exp(segment_sum(A)) + # segment_sum input: [b, h, num_chunks, c] + # segment_sum output: [b, h, num_chunks, c, c] + # Complexity: O(b * h * num_chunks * c^2) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Heads (h): FULLY PARALLEL + # - Chunks (num_chunks): FULLY PARALLEL + # - Within chunk (c): SEQUENTIAL (cumsum) + # - Output dimension (c): Creates new dimension, parallel + # + # TENSOR PARALLEL: Can split b, h, num_chunks across GPUs with NO communication + # The c x c matrix per chunk is computed locally on each GPU + L = self.segment_sum(A_cumsum) # [b, h, num_chunks, c, c] + + # ============================================================================= + # STEP 3d: Attention-like Computation (G matrix) + # ============================================================================= + # G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] + # Input C: [b, num_chunks, c, h, n] + # Input B: [b, num_chunks, c, h, n] + # Output: [b, num_chunks, c, c, h, n] + # Then: G = G_intermediate.sum(dim=-1) -> [b, num_chunks, c, c, h] + # Complexity: O(b * num_chunks * c^2 * h * n) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Chunks (num_chunks): FULLY PARALLEL + # - Query positions (c): FULLY PARALLEL + # - Key positions (c): FULLY PARALLEL + # - Heads (h): FULLY PARALLEL + # - State dimension (n): PARALLEL with ALL_REDUCE (this is reduction dim) + # * If split n across GPUs, need all_reduce after sum + # + # TENSOR PARALLEL STRATEGIES: + # 1. Split any of (b, num_chunks, h) with no communication + # 2. 
Split n with all_reduce after reduction + # 3. This is similar to attention QK^T computation! + C_expanded = torch.zeros(batch_size, + (seq_len + pad_size) // self.chunk_size, + self.chunk_size, self.chunk_size, + self.num_heads, self.ssm_state_size) + B_expanded = torch.zeros(batch_size, + (seq_len + pad_size) // self.chunk_size, + self.chunk_size, self.chunk_size, + self.num_heads, self.ssm_state_size) + G_intermediate = C_expanded * B_expanded # [b, num_chunks, c, c, h, n] + + # Reduction over state dimension + # Input: [b, num_chunks, c, c, h, n] + # Output: [b, num_chunks, c, c, h] + # Complexity: O(b * num_chunks * c^2 * h * n) + # Parallel: Reduction over n - if n is split, need all_reduce + G = G_intermediate.sum(dim=-1) # [b, num_chunks, c, c, h] + + # ============================================================================= + # STEP 3e: Attention Weights Computation (M matrix) + # ============================================================================= + # M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + # After permute, L: [b, num_chunks, c, c, h] + # G: [b, num_chunks, c, c, h] + # M_intermediate: [b, num_chunks, c, c, h, d] (after broadcasting) + # M = M_intermediate.sum(dim=-1) -> [b, num_chunks, c, c, h, d] + # Complexity: O(b * num_chunks * c^2 * h * d) + # + # PARALLELIZATION: + # - All of (b, num_chunks, c, c, h, d) are independent in the outer product + # - The sum reduction is over a broadcasted dimension + # - FULLY PARALLEL across b, num_chunks, c, c, h + # - d dimension: depends on reduction + L_permuted = L.permute(0, 2, 3, 4, 1) # [b, num_chunks, c, c, h] + M_intermediate = torch.zeros(batch_size, + (seq_len + pad_size) // self.chunk_size, + self.chunk_size, self.chunk_size, + self.num_heads, self.head_dim) + M = M_intermediate.sum( + dim=-1) # Simplified - actual computation more complex + + # ============================================================================= + # STEP 3f: Intra-chunk Output (Y_diag) + # ============================================================================= + # Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + # M: [b, num_chunks, c, c, h, d] + # hidden_states after chunking: [b, num_chunks, c, h, d] + # After broadcasting: [b, num_chunks, c, c, h, d] + # Output after sum: [b, num_chunks, c, h, d] + # Complexity: O(b * num_chunks * c^2 * h * d) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Chunks (num_chunks): FULLY PARALLEL + # - Output positions (c, dim=2): FULLY PARALLEL + # - Input positions (c, dim=3): PARALLEL with ALL_REDUCE (reduction dimension) + # - Heads (h): FULLY PARALLEL + # - Head_dim (d): FULLY PARALLEL + # + # This is essentially the attention "apply to values" step! 
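+ # A hedged one-line sketch of this contraction, following the shape annotations
+ # above (q and k both of length c, X_chunked being the chunked hidden_states):
+ #   Y_diag ~= torch.einsum("bzqkhd,bzkhd->bzqhd", M, X_chunked)
+ # The n reduction already happened when forming G, so only the k positions
+ # (input positions within the chunk) are reduced here.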
+ # TENSOR PARALLEL: Split b, num_chunks, h with no communication + # If split input c dimension, need all_reduce + Y_diag = torch.zeros(batch_size, + (seq_len + pad_size) // self.chunk_size, + self.chunk_size, self.num_heads, self.head_dim) + + # ============================================================================= + # STEP 3g: Intra-chunk State Computation + # ============================================================================= + # B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + # states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + # B_decay: [b, num_chunks, c, h, n] + # hidden_states: [b, num_chunks, c, h, d] + # After broadcasting: [b, num_chunks, c, h, d, n] + # After sum over c: [b, num_chunks, h, d, n] + # Complexity: O(b * num_chunks * c * h * d * n) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Chunks (num_chunks): FULLY PARALLEL (each chunk's state independent) + # - Sequence within chunk (c, dim=2): PARALLEL with ALL_REDUCE (reduction) + # - Heads (h): FULLY PARALLEL + # - Head_dim (d): FULLY PARALLEL + # - State_size (n): FULLY PARALLEL + # + # TENSOR PARALLEL: Can split b, num_chunks, h, d, n with no communication + # If split c dimension, need all_reduce after sum + states = torch.zeros(batch_size, + (seq_len + pad_size) // self.chunk_size, + self.num_heads, self.head_dim, self.ssm_state_size) + + # ============================================================================= + # STEP 3h: Inter-chunk Recurrence (Sequential Across Chunks!) + # ============================================================================= + # decay_chunk = torch.exp(segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + # decay_chunk = decay_chunk.transpose(1, 3) # [b, num_chunks+1, num_chunks+1, h] + # new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + # + # Input states: [b, num_chunks, h, d, n] + # decay_chunk: [b, num_chunks+1, num_chunks+1, h] + # new_states: [b, num_chunks+1, h, d, n] + # + # Complexity: O(b * num_chunks^2 * h * d * n) + # + # PARALLELIZATION - CRITICAL INSIGHT: + # - Batch (b): FULLY PARALLEL + # - Heads (h): FULLY PARALLEL + # - Head_dim (d): FULLY PARALLEL + # - State_size (n): FULLY PARALLEL + # - Chunks (num_chunks): SEQUENTIAL!!! This is a recurrence across chunks! + # + # **This is the main sequential bottleneck for long sequences!** + # + # The sum over dim=1 creates a dependency between chunks: + # new_states[chunk_i] depends on states[0:i] + # + # TENSOR PARALLEL STRATEGIES: + # 1. Can split b, h, d, n across GPUs with no communication + # 2. CANNOT efficiently parallelize across chunks without changing algorithm + # 3. For very long sequences, this becomes a bottleneck + # 4. 
Possible solution: Use ring-reduce or prefix-sum parallel algorithms + # but this requires O(log num_chunks) communication rounds + # + # Alternative: Pipeline parallelism - process chunks sequentially but + # overlap computation of different layers + new_states = torch.zeros(batch_size, + (seq_len + pad_size) // self.chunk_size + 1, + self.num_heads, self.head_dim, + self.ssm_state_size) + + # Extract final state and intermediate states + states = new_states[:, :-1] # [b, num_chunks, h, d, n] + ssm_state = new_states[:, -1] # [b, h, d, n] - final state for caching + + # ============================================================================= + # STEP 3i: State to Output (Y_off) + # ============================================================================= + # C_times_states = (C[..., None, :] * states[:, :, None, ...]) + # Input C: [b, num_chunks, c, h, n] + # Input states: [b, num_chunks, h, d, n] + # After broadcast: [b, num_chunks, c, h, d, n] + # Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # After sum over n: [b, num_chunks, c, h, d] + # Complexity: O(b * num_chunks * c * h * d * n) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Chunks (num_chunks): FULLY PARALLEL (using precomputed states) + # - Positions (c): FULLY PARALLEL + # - Heads (h): FULLY PARALLEL + # - Head_dim (d): FULLY PARALLEL + # - State_size (n): PARALLEL with ALL_REDUCE (reduction dimension) + # + # TENSOR PARALLEL: Split b, num_chunks, c, h, d with no communication + # Split n requires all_reduce after sum + Y_off = torch.zeros(batch_size, (seq_len + pad_size) // self.chunk_size, + self.chunk_size, self.num_heads, self.head_dim) + + # ============================================================================= + # STEP 3j: Combine Intra-chunk and Inter-chunk Outputs + # ============================================================================= + # y = Y_diag + Y_off + # Both: [b, num_chunks, c, h, d] + # Output: [b, num_chunks, c, h, d] -> [b, s, h, d] -> [b, s, i] + # Complexity: O(b * s * i) for reshape + # Parallel: FULLY PARALLEL (element-wise addition) + y = Y_diag + Y_off # [b, num_chunks, c, h, d] + y = y.reshape(batch_size, -1, self.num_heads, + self.head_dim) # [b, s_padded, h, d] + y = y[:, :seq_len, :, :] # Remove padding: [b, s, h, d] + y = y.reshape(batch_size, seq_len, self.intermediate_size) # [b, s, i] + + # ============================================================================= + # STEP 4: Output Projection (Linear Layer) + # ============================================================================= + # contextualized_states = y @ out_proj.weight^T + out_proj.bias + # Input: [b, s, i] + # Weight: [h_in, i] + # Output: [b, s, h_in] + # Complexity: O(b * s * i * h_in) + # + # PARALLELIZATION: + # - Batch (b): FULLY PARALLEL + # - Sequence (s): FULLY PARALLEL + # - Input dim (i): PARALLEL with ALL_REDUCE (reduction dimension) + # - Output dim (h_in): FULLY PARALLEL (row-parallel) + # + # TENSOR PARALLEL STRATEGIES: + # 1. Column parallel on i: split weight columns, all_reduce after matmul + # 2. Row parallel on h_in: split weight rows, no all_reduce needed + # 3. 
Typically: in_proj is column-parallel, out_proj is row-parallel + # This minimizes communication (1 all_reduce per layer) + contextualized_states = self.out_proj(y) # [b, s, h_in] + + return contextualized_states + + def summarize_parallelization_strategies(self): + """ + SUMMARY OF PARALLELIZATION STRATEGIES FOR MULTI-GPU DEPLOYMENT + ================================================================ + + DIMENSIONS AND THEIR PARALLELIZABILITY: + + 1. BATCH (b) - EMBARRASSINGLY PARALLEL + - Can split across GPUs with ZERO communication + - Each GPU processes different examples + - Strategy: Data Parallelism + + 2. SEQUENCE (s) - MOSTLY PARALLEL with caveats + - Linear layers: FULLY PARALLEL + - Conv1d: Needs halo exchange (kernel_size-1 elements) + - Attention-like ops: FULLY PARALLEL + - Within chunks: FULLY PARALLEL + - Across chunks: SEQUENTIAL (recurrence) + - Strategy: Sequence Parallelism (limited by chunk recurrence) + + 3. HEADS (h) - FULLY PARALLEL + - All operations independent across heads + - No communication needed + - Strategy: Tensor Parallelism on head dimension + + 4. HEAD_DIM (d) - PARALLEL (no reductions in this dim) + - Can split with no all_reduce + - Strategy: Tensor Parallelism on head_dim + + 5. HIDDEN_DIM (h_in, i) - PARALLEL with ALL_REDUCE + - Linear layers: reduction dimension + - Need all_reduce when splitting this dimension + - Strategy: Tensor Parallelism (column-parallel in, row-parallel out) + + 6. STATE_SIZE (n) - PARALLEL with ALL_REDUCE + - Reduction dimension in attention-like operations + - Need all_reduce when computing G and Y_off + - Strategy: Tensor Parallelism on state dimension + + 7. NUM_CHUNKS - MOSTLY PARALLEL + - Each chunk computation: FULLY PARALLEL + - Chunk recurrence (Step 3h): SEQUENTIAL + - Strategy: Pipeline or sequential processing + + 8. CHUNK_SIZE (c) - MIXED + - Cumsum/segment_sum: SEQUENTIAL within chunk + - Other ops: PARALLEL + - Cannot split within chunk effectively + + RECOMMENDED MULTI-GPU STRATEGIES: + ================================== + + Strategy 1: TENSOR + DATA PARALLEL (Most Common) + ------------------------------------------------- + - Split batch across data-parallel GPUs (no communication) + - Within each data-parallel group, use tensor parallelism: + * Split num_heads across GPUs (no communication in compute) + * Column-parallel in_proj, row-parallel out_proj (1 all_reduce per layer) + - Works well for moderate sequence lengths + - Communication: O(b * s * h_in) per layer for all_reduce + + Strategy 2: SEQUENCE PARALLEL (For Very Long Sequences) + -------------------------------------------------------- + - Split sequence dimension across GPUs + - Requires: + * Halo exchange for conv1d (small overhead) + * Sequential processing of chunk recurrence (pipelined) + - Best for: seq_len >> hidden_size + - Communication: O(conv_kernel * features) for halo + pipeline latency + + Strategy 3: EXPERT PARALLEL (If MOE layers present) + --------------------------------------------------- + - Not shown in this code, but relevant for full model + - Split experts across GPUs + - All-to-all communication for routing + + Strategy 4: PIPELINE PARALLEL (For Very Large Models) + ----------------------------------------------------- + - Split layers across GPUs + - Process micro-batches in pipeline + - Communication: O(b * s * h_in) per pipeline stage boundary + + CRITICAL BOTTLENECKS: + ===================== + + 1. 
CHUNK RECURRENCE (Step 3h) + - Sequential across chunks + - Cannot parallelize without algorithmic changes + - For long sequences with many chunks, this limits speedup + - Mitigation: Use larger chunk_size (but increases memory) + + 2. CUMSUM OPERATIONS + - Sequential within each chunk + - Limits parallelism to chunk_size granularity + - Cannot split chunk_size dimension across GPUs + + 3. ALL_REDUCE COMMUNICATION + - Required when splitting reduction dimensions + - Latency increases with number of GPUs + - Bandwidth-bound for large tensors + + 4. CONV1D HALO EXCHANGE + - Required for sequence parallelism + - Small overhead but adds latency + + OPTIMAL CONFIGURATION (Example for 8 GPUs): + =========================================== + - Use 4-way tensor parallelism on heads (split 64 heads -> 16 per GPU) + - Use 2-way data parallelism on batch + - Keep sequence on single GPU (if possible) + - If sequence too long: + * Use sequence parallelism with 2-4 way split + * Accept chunk recurrence as sequential bottleneck + + This gives: + - ~4x speedup from tensor parallelism (limited by all_reduce) + - ~2x speedup from data parallelism (perfect scaling) + - Total: ~6-7x speedup on 8 GPUs (75-85% efficiency) + """ + + +def main(): + """ + Example usage showing the tensor shapes through the forward pass. + """ + import sys + + # Check if debug mode is requested + debug = "--debug" in sys.argv + + mixer = NemotronHMamba2Mixer() + + # Example input + batch_size = 4 + seq_len = 1024 + hidden_size = 4096 + + input_states = torch.randn(batch_size, seq_len, hidden_size) + + # Forward pass + output = mixer.torch_forward_algebra_only(input_states, debug=debug) + + print(f"\n{'='*80}") + print("NEMOTRON-H MAMBA2 MIXER - TENSOR ALGEBRA ANALYSIS") + print(f"{'='*80}") + print(f"\nInput shape: {input_states.shape}") + print(f"Output shape: {output.shape}") + print("\nConfiguration:") + print(f" - Batch size: {batch_size}") + print(f" - Sequence length: {seq_len}") + print(f" - Hidden size: {hidden_size}") + print(f" - Num heads: {mixer.num_heads}") + print(f" - Head dim: {mixer.head_dim}") + print(f" - Intermediate size: {mixer.intermediate_size}") + print(f" - Chunk size: {mixer.chunk_size}") + print( + f" - Num chunks: {(seq_len + mixer.chunk_size - 1) // mixer.chunk_size}" + ) + print(f"\n{'='*80}") + print("See docstrings in the code for detailed parallelization analysis.") + print("Run with --debug flag to see intermediate tensor shapes.") + print(f"{'='*80}\n") + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 13f1cf0703f..1acdc118fdf 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -75,8 +75,8 @@ transforms: detect_sharding: stage: sharding simple_shard_only: false - use_sharding_from_factory: false - support_partial_config: false + use_sharding_from_factory: true + support_partial_config: true sharding_dims: ['tp', 'ep', 'bmm'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. 
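As background for the factory-sharding changes that follow, the "colwise"/"rowwise" pairing referenced throughout this patch is the classic Megatron-style tensor-parallel split summarized in simple_nemotron.py (Strategy 1): split the first projection along its output features, split the second along its input features, and reconstruct the output with a single all_reduce per block. The sketch below is a minimal, hand-written illustration of that pairing, assuming an already-initialized torch.distributed process group; the helper names (shard_colwise, shard_rowwise, tp_block_forward) are invented for this example and are not the auto_deploy transforms introduced by this patch.

```python
# Illustrative only: assumes torch.distributed is initialized (e.g. via torchrun)
# and that every rank starts from the full, unsharded weights. Biases are omitted.
import torch
import torch.nn as nn
import torch.distributed as dist


def shard_colwise(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Column-parallel: keep 1/world_size of the output features on this rank."""
    w = linear.weight.detach().chunk(world_size, dim=0)[rank]
    shard = nn.Linear(linear.in_features, w.shape[0], bias=False)
    shard.weight = nn.Parameter(w.clone())
    return shard


def shard_rowwise(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Row-parallel: keep 1/world_size of the input features on this rank."""
    w = linear.weight.detach().chunk(world_size, dim=1)[rank]
    shard = nn.Linear(w.shape[1], linear.out_features, bias=False)
    shard.weight = nn.Parameter(w.clone())
    return shard


def tp_block_forward(x: torch.Tensor, up_shard: nn.Linear, down_shard: nn.Linear) -> torch.Tensor:
    # Column-parallel projection: each rank produces a disjoint slice of the
    # intermediate features, so no communication is needed here.
    h = torch.nn.functional.relu(up_shard(x))  # any elementwise activation behaves the same
    # Row-parallel projection: each rank now holds a partial sum over its input
    # slice, so a single all_reduce per block reconstructs the full output.
    y = down_shard(h)
    dist.all_reduce(y, op=dist.ReduceOp.SUM)
    return y


# Usage per rank: up = shard_colwise(full_up, rank, ws); down = shard_rowwise(full_down, rank, ws)
```

Fused projections such as Mamba's in_proj cannot be split with one uniform column chunk, which appears to be why the patch below threads per-segment sizes (fused_weight_dims) through the TP plan before the column split is applied.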
diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index 396711bd80c..3da0a8f5baa 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -9,8 +9,14 @@ from einops import rearrange from transformers import AutoModelForCausalLM +from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward +# from transformers.models.nemotron_h.configuration_nemotron_h import NemotronHConfig + +# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57 +# NemotronHConfig.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise" + # Forked from: # https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py @@ -79,7 +85,7 @@ def _nemotron_h_block_forward( elif self.block_type == "attention": hidden_states = self.mixer(hidden_states, cache_position=cache_position) hidden_states = hidden_states[0] - elif self.block_type == "mlp": + elif self.block_type in ["mlp", "moe"]: hidden_states = self.mixer(hidden_states) else: raise ValueError(f"Invalid block_type: {self.block_type}") @@ -88,6 +94,34 @@ def _nemotron_h_block_forward( return hidden_states +# TODO: we assume experts have no bias for now +def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): + """ + Uses NemotronH router (returns indices, weights) and dispatches through auto_deploy::torch_moe_nemo + with act_fn='relu2'. Falls back to original forward if any expert has bias. + """ + + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + x_flat = hidden_states.view(-1, hidden_states.shape[-1]) + + out_flat = torch.ops.auto_deploy.torch_moe( + x_flat, + topk_indices, + topk_weights, + w1_weight=[e.up_proj.weight for e in self.experts], + w2_weight=[e.down_proj.weight for e in self.experts], + w3_weight=[], + act_fn="relu2", + mlp_style="mlp", + ) + + out = out_flat.view(*orig_shape) + out = out + self.shared_experts(residuals) + return out + + _from_config_original = AutoModelForCausalLM.from_config CUSTOM_MODULE_PATCHES: Dict[str, List[Tuple[str, Callable]]] = { @@ -97,6 +131,7 @@ def _nemotron_h_block_forward( ("_update_mamba_mask", _nemotron_h_model_update_mamba_mask), ], "NemotronHBlock": [("forward", _nemotron_h_block_forward)], + "NemotronHMOE": [("forward", _nemotron_h_moe_forward)], } @@ -112,6 +147,18 @@ def get_model_from_config_patched(config, **kwargs): return model +def _set_sharding_config_patched(self, *args, **kwargs): + self._sharding_config["head_dim"] = 128 + self._sharding_config["tp_plan"] = { + "in_proj": "colwise_fused[21504,21504,512,512,32]", + "conv1d": "colwise_fused[21504,512,512]", + "out_proj": "rowwise", + "*": "gather", + } + + +AutoModelForCausalLMFactory._set_sharding_config = _set_sharding_config_patched + # TODO: figure out how this can be incorporated into the export patch system AutoModelForCausalLM.from_config = get_model_from_config_patched diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index bc454e69396..c34e2bfb634 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -239,12 +239,6 @@ def filtered_nodes( for node in nodes: if target(node): 
yield node - elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target): - for node in nodes: - for t in target: - if t(node): - yield node - break else: # Handle the case where target or ops contains operations operations = ops if ops is not None else target @@ -468,7 +462,7 @@ def predecessors( continue if (not include) or (include and include(arg)): preds.append(arg) - return list(reversed(preds)) + return preds.reverse() def successors( @@ -491,4 +485,4 @@ def successors( continue if (not include) or (include and include(user)): succs.append(user) - return list(reversed(succs)) + return succs.reverse() diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index a95045b7d28..491c15cf5eb 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -248,6 +248,8 @@ class TPShardingInfo(ShardingTransformInfo): split_dim: SplitDimension dist_op: Optional[Literal["all_reduce", "all_gather"]] = None min_local_shape: int = 1 + # used for TP sharding of fused weights + fused_weights: Optional[list] @classmethod def from_node(cls, node: Node, **kwargs) -> "TPShardingInfo": From 533f70951f735d81528a7d6741a72130b6a020a5 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Fri, 17 Oct 2025 02:52:30 -0700 Subject: [PATCH 02/12] fixed preds logic Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index c34e2bfb634..441b456ec54 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -462,7 +462,7 @@ def predecessors( continue if (not include) or (include and include(arg)): preds.append(arg) - return preds.reverse() + return list(reversed(preds)) def successors( @@ -485,4 +485,4 @@ def successors( continue if (not include) or (include and include(user)): succs.append(user) - return succs.reverse() + return list(reversed(succs)) From a21642cefb84b8f259dc20c46ffc55c71aa3d5ad Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Mon, 20 Oct 2025 06:11:27 -0700 Subject: [PATCH 03/12] WiP sharding nemotron Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/models/patches/nemotron_h.py | 7 +- .../transform/library/fused_moe.py | 6 +- .../auto_deploy/transform/library/sharding.py | 35 ++ .../_torch/auto_deploy/utils/node_utils.py | 134 ++++- .../auto_deploy/utils/sharding_utils.py | 470 ++++++++++++++---- 5 files changed, 513 insertions(+), 139 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index 3da0a8f5baa..b46d48848f4 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -150,10 +150,11 @@ def get_model_from_config_patched(config, **kwargs): def _set_sharding_config_patched(self, *args, **kwargs): self._sharding_config["head_dim"] = 128 self._sharding_config["tp_plan"] = { - "in_proj": "colwise_fused[21504,21504,512,512,32]", - "conv1d": 
"colwise_fused[21504,512,512]", + "in_proj": 'mamba("fused_weight_dims" = {"in_proj": [8192,8192,1024,1024,128], "conv1d": [8192, 1024, 1024]})', "out_proj": "rowwise", - "*": "gather", + "up_proj": "colwise", + "down_proj": "rowwise", + # "*": "gather", } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 669357b1399..6980b5dadc3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -294,7 +294,7 @@ def _find_final_hidden_state_node( if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2): return None index_node = mul_node.args[1] - index_add_node = bfs( + index_add_node, _ = bfs( index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary ) if not index_add_node: @@ -360,7 +360,7 @@ def target(n: torch.fx.Node) -> bool: return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0 try: - node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) + node_to_remove, _ = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) graph.erase_node(node_to_remove) return True except RuntimeError: @@ -430,7 +430,7 @@ def _apply( common_ancessor2 = _find_lowest_common_ancessor(arg2_list) if not common_ancessor2: continue - selected_experts = bfs( + selected_experts, _ = bfs( common_ancessor2, lambda node: is_op(node, torch.ops.aten.one_hot), attr_next="all_input_nodes", diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index d6de54b22d2..f640f09e159 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -16,6 +16,7 @@ happens automatically via the checkpoint loading hook added in step 2c. """ +import ast import operator import re from collections import defaultdict @@ -38,6 +39,7 @@ from ...utils.sharding_utils import ( BMMShardingInfo, EPShardingInfo, + LayerType, ShardingConfig, ShardingTransformInfo, SplitDimension, @@ -260,6 +262,7 @@ def detect_sharding_from_factory_config( # 4. the allowed values are: # - "colwise" # - "rowwise" + # - "mamba" # - "sequence_parallel" # - "local_colwise" # - "local_rowwise" @@ -313,6 +316,24 @@ def detect_sharding_from_factory_config( num_shards += 1 # we have a match. Get the config for this layer config = tp_plan[key] + # check if config has parameters. + if "(" in config: + config, params_str = config.split("(", 1) + params_str = params_str.rsplit(")", 1)[0] # Remove trailing ) + + try: + # Convert "key" = value to "key": value format for dict parsing + params_str = params_str.replace(" = ", ": ") + # Wrap in braces to make it a dict and parse + config_params = ast.literal_eval("{" + params_str + "}") + except Exception as e: + ad_logger.warning( + f"Failed to parse config params: {params_str}, error: {e}. " + "Using empty config." 
+ ) + config_params = {} + else: + config_params = {} if config == "colwise": sharding_config.tp_transforms.append( TPShardingInfo.from_node( @@ -336,6 +357,20 @@ def detect_sharding_from_factory_config( ) ) num_row_col_shards += 1 + elif config == "mamba": + sharding_config.tp_transforms.append( + TPShardingInfo.from_node( + lin_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + layer_type=LayerType.MAMBA, + fused_weight_dims=config_params.get("fused_weight_dims"), + ) + ) + num_row_col_shards += 1 elif "sequence" in config: # TODO: Sequence parallelism is not supported yet. ad_logger.warning("Sequence parallelism is not supported yet. Skipping.") diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 441b456ec54..da9584c1f08 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -360,22 +360,37 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]: def bfs( - node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None -) -> Node: - queue = [node] + node: Node, + target: Callable, + attr_next: str = "users", + boundary: Optional[Node] = None, + include_root: bool = True, +) -> Tuple[Node, int]: + """ + Breadth-first search of the graph. + Returns the found node and the depth of the node. + """ + depth = 0 + queue_at_depth = [node] + queue_at_depth_next = [] visited = set() - while queue: - cur_node = queue.pop(0) + while queue_at_depth or queue_at_depth_next: + cur_node = queue_at_depth.pop(0) if boundary is not None and cur_node == boundary: continue # Skip the boundary node. - if target(cur_node): - return cur_node + if target(cur_node) and (include_root or depth > 0): + return cur_node, depth for next_node in getattr(cur_node, attr_next): if boundary is not None and next_node == boundary: continue # Do not expand past the boundary. if next_node not in visited: visited.add(next_node) - queue.append(next_node) + queue_at_depth_next.append(next_node) + if not queue_at_depth: + queue_at_depth = queue_at_depth_next + queue_at_depth_next = [] + depth += 1 + raise RuntimeError(f"Could not find node with target condition {target}.") @@ -450,19 +465,19 @@ def predecessors( If exclude is provided, exclude nodes that satisfy the condition. """ preds = [] + seen = set() for arg in node.args: if isinstance(arg, Node): + if ((not include) or (include and include(arg))) and (not exclude or not exclude(arg)): + if arg not in seen: + preds.append(arg) + seen.add(arg) if depth > 1: - preds.extend(predecessors(arg, depth - 1, include, exclude)) - # add node arg if either: - # a) include and exclude are not specified - # b) include is specified and arg satisfies include condition - # c) exclude is specified and arg does not satisfy exclude condition - if exclude and exclude(arg): - continue - if (not include) or (include and include(arg)): - preds.append(arg) - return list(reversed(preds)) + for p in predecessors(arg, depth - 1, include, exclude): + if p not in seen: + preds.append(p) + seen.add(p) + return preds def successors( @@ -477,12 +492,83 @@ def successors( If exclude is provided, exclude nodes that satisfy the condition. 
""" succs = [] + seen = set() for user in node.users: + if ((not include) or (include and include(user))) and (not exclude or not exclude(user)): + if user not in seen: + succs.append(user) + seen.add(user) if depth > 1: - succs.extend(successors(user, depth - 1, include, exclude)) - # analogous logic to predecessors - if exclude and exclude(user): + for s in successors(user, depth - 1, include, exclude): + if s not in seen: + succs.append(s) + seen.add(s) + return succs + + +def subgraph( + sources: list[Node], + sinks: list[Node], + include_boundary_nodes: bool = True, + include: Optional[Callable[[Node], bool]] = None, + exclude: Optional[Callable[[Node], bool]] = None, +) -> List[Node]: + """ + Returns a list of nodes in a subgraph in computation DAG defined as all nodes + succeeding any of the node in sources and preceding any of the nodes in sinks. + It is built by a BFS traversal from sinks, where the sources list acts as a + boundary. We do it in this order (and not from sources to sinks) to include + nodes like weights or other inputs (they are not successors of sinks, so otherwise + they wouldn't be included). + + Optionally, include or exclude conditions may be specified to include [exclude] + only nodes that meet [don't meet] certain condition. + """ + subgraph_nodes = [] + seen = set() + queue = list(sinks) + sources_set = set(sources) + + # Initialize queue with sinks and mark them as seen + for node in sinks: + if node not in seen: + seen.add(node) + + # BFS traversal from sinks backwards through predecessors + while queue: + node = queue.pop(0) + + # Check if node should be included based on filters + should_include = True + if include is not None and not include(node): + should_include = False + if exclude is not None and exclude(node): + should_include = False + if not include_boundary_nodes and (node in sources_set) or (node in sinks): + should_include = False + + if should_include: + subgraph_nodes.append(node) + + # Stop traversal at source nodes (boundary) - don't explore their predecessors + if node in sources_set: continue - if (not include) or (include and include(user)): - succs.append(user) - return list(reversed(succs)) + + # Traverse to predecessor nodes (all inputs to this node) + for arg in node.args: + if isinstance(arg, Node) and arg not in seen: + seen.add(arg) + queue.append(arg) + + return subgraph_nodes + + +def draw_graph(gm: GraphModule, filename: str): + """ + Dump graphmodule to SVG file using PyTorch's built-in drawer. 
+ """ + from torch.fx.passes.graph_drawer import FxGraphDrawer + + drawer = FxGraphDrawer(gm, filename) + with open(f"{filename}.svg", "wb") as f: + f.write(drawer.get_dot_graph().create_svg()) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 491c15cf5eb..f81ab0e488c 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -2,10 +2,11 @@ import math import operator +import re from abc import ABC, abstractmethod -from enum import IntEnum +from enum import Enum, IntEnum from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -14,7 +15,15 @@ from ..models.factory import ShardingConfigSource from ..utils.logger import ad_logger -from .node_utils import extract_param_names_from_lin_node, is_op, num_users_of_weight_node +from .node_utils import ( + bfs, + extract_param_names_from_lin_node, + is_linear_op, + is_op, + num_users_of_weight_node, + subgraph, + successors, +) from .quantization_utils import ( cutlass_fp4_scale_to_modelopt_fp4_scale, modelopt_fp4_scale_to_cutlass_fp4_scale, @@ -54,24 +63,290 @@ def _load_hook_remove( state_dict.pop(key, None) -def _update_view_nodes(node: Node) -> None: +def _validate_sharded_shapes( + node: Node, fused_weight_dims: Optional[list] = None, world_size: int = None +) -> None: """ - After sharding weights of the linear node, using column split + Update the shapes of the view nodes and the split node parameters to account for the TP sharding. + 1. After sharding weights of the linear node using column split in attention module (Q, K, V), - the output Y = X @ W^T is [batch, seq, num_heads // TP_size, head_dim] - Some models hardcode the shape of the output to be [batch, seq, num_heads, head_dim] + the output Y = X @ W^T shape is [batch, seq, num_heads // TP_size, head_dim]. + Some models hardcode the shape of the output to [batch, seq, num_heads, head_dim] instead of implicit [batch, seq, -1, head_dim]. Detect such cases and update the shape of the view node accordingly. + 2. If the weights are fused (e.g,. QKV, gate_up, SSM, etc.), the follow-up split node parameters + need to be updated to account for the TP sharding. """ - view_nodes = [n for n in node.users if is_op(n, torch.ops.aten.view)] - for view_node in view_nodes: + + # get the subgraph of this module. Subgraph boundary is the next linear node. + next_lin_node, depth = bfs(node, is_linear_op, include_root=False) + # split nodes can't have "-1" for split size. 
+ nodes_to_validate = successors( + node, + depth=depth, + exclude=lambda n: is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]), + ) + for view_node in nodes_to_validate: + # shard weight tensors for RMS norm and conv1d + if len(view_node.args) < 2: + continue view_shape = view_node.args[1] - if len(view_shape) == 4 and view_shape[2] != -1: + if not isinstance(view_shape, list): + continue + if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: args = list(view_node.args) - args[1] = [view_shape[0], view_shape[1], -1, view_shape[3]] + args[1] = [view_shape[0], view_shape[1], -1] + view_shape[3:] view_node.args = tuple(args) ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") + # if fused_weight_dims is provided, we need to update all split sizes + if fused_weight_dims is not None: + assert world_size is not None, "World size is required to update the split node params" + assert len(node.users) == 1, "Fused linear node should have only one user: a split node" + # find all split nodes in the region between this linear node and the next + split_nodes = successors( + node, + depth=depth, + include=lambda n: is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]), + ) + for split_node in split_nodes: + orig_sizes = split_node.args[1] + new_sizes = [orig_sizes[i] // world_size for i in range(len(orig_sizes))] + args = list(split_node.args) + args[1] = new_sizes + split_node.args = tuple(args) + ad_logger.debug(f"\nUpdated split node {split_node} arguments to {split_node.args}") + + +def shard_weight_tensor( + gm: GraphModule, + weight_tensor: torch.Tensor, + param_key: str, + dim: int, + rank: int, + world_size: int, + min_local_shape: int = 1, + fused_weight_dims: Optional[list] = None, + custom_shard_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + requires_grad: bool = False, + update_param: bool = True, +) -> Tuple[torch.Tensor, torch.Size]: + """Shard a weight tensor across ranks and register load hook. + + Args: + gm: GraphModule containing the weight + weight_tensor: The weight tensor to shard + param_key: Parameter key for registering load hook + dim: Dimension to shard along + rank: Current rank + world_size: Total number of ranks + min_local_shape: Minimum local shape constraint (for GQA) + fused_weight_dims: List of dimensions for fused weights + custom_shard_fn: Optional custom function to shard the tensor + requires_grad: Whether the parameter should require gradients + update_param: Whether to update the parameter in the module + + Returns: + Tuple of (sharded_tensor, sharded_shape) + """ + + # Use custom shard function if provided + if custom_shard_fn is not None: + sharded_weight = custom_shard_fn(weight_tensor) + sharded_shape = sharded_weight.shape + # Register load hook with custom function + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, f_split=custom_shard_fn, param_key=param_key, param_shape=sharded_shape + ) + ) + else: + + def split_tensor( + t: torch.Tensor, + d: int = dim, + r: int = rank, + ws: int = world_size, + min_d_shape: int = min_local_shape, + ) -> torch.Tensor: + # The local tensor shape has to be divisible by min_d_shape + max_split_size = t.shape[d] // min_d_shape + if ws > max_split_size: + num_groups = math.ceil(ws / max_split_size) + ad_logger.debug( + f"World size {ws} is greater than the max split size {max_split_size}. 
" + + f"Splitting tensor to {num_groups} chunks" + ) + return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups] + return torch.tensor_split(t, ws, dim=d)[r] + + # Handle fused weights + if fused_weight_dims is not None: + # Split fused weights, apply TP sharding to each, then concatenate back + sharded_weight = torch.cat( + [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], + dim=dim, + ) + else: + sharded_weight = split_tensor(weight_tensor) + + sharded_shape = sharded_weight.shape + + # Register load hook + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, f_split=split_tensor, param_key=param_key, param_shape=sharded_shape + ) + ) + + # Update the parameter in the module + if update_param: + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) + param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=requires_grad) + setattr(submod, param_name, param_new) + + return sharded_weight, sharded_shape + + +def get_all_weights_in_subgraph( + sources: list[Node], + sinks: list[Node], +): + """Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks.""" + weight_nodes = subgraph( + sources, sinks, include_boundary_nodes=False, include=lambda n: n.op == "get_attr" + ) + return weight_nodes + + +def _insert_sharded_mamba( + gm: GraphModule, + entry_node: Node, + dim: int, + rank: int, + world_size: int, + add_dist: bool = False, + min_local_shape: int = 1, + weights_to_shard: Optional[list[str]] = None, + weight_shard_dims: Optional[Dict[str, int]] = None, + fused_weight_dims: Optional[Dict[str, list]] = None, + quantization_cb: Optional[ + Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] + ] = None, +) -> None: + """ + To shard Mamba layer, first column-shard the first linear layer: entry_node, + then shard all remaining weight tensors found in the subgraph defined between + entry_node and the next successor linear node. + First, validate if this is indeed a mamba module: within the subgraph, + there should be an torch_ssm node and conv1d node. + + Args: + gm: GraphModule + entry_node: The first linear node of the Mamba layer + dim: Default shard dimension + rank: Current rank + world_size: Total number of ranks + add_dist: Whether to add distribution op after entry_node + min_local_shape: Minimum local shape constraint + weights_to_shard: Optional list of regex patterns to match weight names + weight_shard_dims: Optional dict mapping weight keys to their shard dimensions + fused_weight_dims: Optional dict mapping weight keys to their fused dimension lists + quantization_cb: Optional quantization callback + """ + # Find next linear node to define subgraph boundary + try: + next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False) + except RuntimeError: + ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") + return + + # Get subgraph between entry_node and next linear node + subgraph_nodes = subgraph([entry_node], [next_lin_node]) + + # Validate this is a Mamba module by checking for torch_ssm_transform and conv1d + has_ssm = any(is_op(n, torch.ops.auto_deploy.torch_ssm_transform) for n in subgraph_nodes) + has_conv1d = any( + is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) + for n in subgraph_nodes + ) + + if not (has_ssm and has_conv1d): + ad_logger.warning( + f"Subgraph does not contain both torch_ssm_transform and conv1d nodes. 
" + f"Skipping Mamba sharding. has_ssm={has_ssm}, has_conv1d={has_conv1d}" + ) + return + + # First, shard the entry_node (the first linear layer) + # Extract entry node's fused_weight_dims by matching weight name against patterns + entry_fused_dims = None + if fused_weight_dims: + entry_weight_key, _ = extract_param_names_from_lin_node(entry_node) + for pattern, dims in fused_weight_dims.items(): + if re.search(pattern, entry_weight_key): + entry_fused_dims = dims + break + + _insert_sharded_matmul( + gm=gm, + node=entry_node, + dim=dim, + rank=rank, + world_size=world_size, + add_dist=add_dist, + min_local_shape=min_local_shape, + fused_weight_dims=entry_fused_dims, + quantization_cb=quantization_cb, + ) + + # Get all weight nodes in the subgraph + weight_nodes = [ + n + for n in get_all_weights_in_subgraph([entry_node], [next_lin_node]) + if "out_proj" not in str(n) + ] + + # Shard remaining weights + for weight_node in weight_nodes: + weight_key = weight_node.target + + # Filter by regex patterns if provided + if weights_to_shard is not None: + if not any(re.search(pattern, weight_key) for pattern in weights_to_shard): + continue + + # Determine shard dimension for this weight + shard_dim = weight_shard_dims.get(weight_key, dim) if weight_shard_dims else dim + + # Get the weight parameter + try: + weight_param = gm.get_parameter(weight_key) + except AttributeError: + ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") + continue + + # Get fused dims for this weight if specified + fused_dims = fused_weight_dims.get(weight_key) if fused_weight_dims else None + + # Shard the weight tensor (also updates the parameter in the module) + _, sharded_shape = shard_weight_tensor( + gm=gm, + weight_tensor=weight_param, + param_key=weight_key, + dim=shard_dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_dims, + ) + + ad_logger.debug( + f"Sharded weight {weight_key} on dim {shard_dim}: " + f"{weight_param.shape} -> {sharded_shape}" + ) + def _insert_sharded_matmul( gm: GraphModule, @@ -81,6 +356,7 @@ def _insert_sharded_matmul( world_size: int, add_dist: bool = False, min_local_shape: int = 1, + fused_weight_dims: Optional[list] = None, quantization_cb: Optional[ Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] ] = None, @@ -92,24 +368,6 @@ def _insert_sharded_matmul( assert dim in [0, 1], "Only dim 0 and 1 are supported for sharding" assert add_dist or dim == 0, "For dim=1 sharding, dist_op is required." - def split_tensor( - t: torch.Tensor, - d: int = dim, - r: int = rank, - ws: int = world_size, - min_d_shape: int = min_local_shape, - ) -> torch.Tensor: - # The local tensor shape has to be divisible by min_d_shape - max_split_size = t.shape[d] // min_d_shape - if ws > max_split_size: - num_groups = math.ceil(ws / max_split_size) - ad_logger.debug( - f"World size {ws} is greater than the max split size {max_split_size}. 
" - + f"Splitting tensor to {num_groups} chunks" - ) - return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups] - return torch.tensor_split(t, ws, dim=d)[r] - num_users = num_users_of_weight_node(node) if num_users > 1 or num_users == 0: ad_logger.warning( @@ -122,36 +380,31 @@ def split_tensor( modname = weight_key.rpartition(".")[0] submod = gm.get_submodule(modname) - def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> torch.Size: - # split or remove it - param_new = ( - None - if remove - else nn.Parameter( - split_tensor(gm.get_parameter(param_key)).detach().clone(), requires_grad=False - ) - ) - - # update the parameter - param_name = param_key.rpartition(".")[-1] - setattr(submod, param_name, param_new) - return torch.Size() if param_new is None else param_new.shape - - # update weight - weight_new_shape = set_new_param(submod, weight_key) - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, f_split=split_tensor, param_key=weight_key, param_shape=weight_new_shape - ) + # Shard weight using the unified function (also updates the parameter) + original_weight = gm.get_parameter(weight_key) + _, weight_new_shape = shard_weight_tensor( + gm=gm, + weight_tensor=original_weight, + param_key=weight_key, + dim=dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims, ) if bias_key is not None and dim == 0: # update bias for dim 0 --> we can handle it like the weight - bias_new_shape = set_new_param(submod, bias_key) - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, f_split=split_tensor, param_key=bias_key, param_shape=bias_new_shape - ) + original_bias = gm.get_parameter(bias_key) + shard_weight_tensor( + gm=gm, + weight_tensor=original_bias, + param_key=bias_key, + dim=dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=None, ) elif bias_key is not None and rank != world_size - 1: # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid @@ -161,7 +414,8 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to args[2] = None node.args = tuple(args) gm.graph.erase_node(node_bias) - set_new_param(submod, bias_key, remove=True) + bias_param_name = bias_key.rpartition(".")[-1] + setattr(submod, bias_param_name, None) gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key)) if quantization_cb is not None: @@ -178,7 +432,7 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to # column shard with no gather: the output is sharded if not add_dist: - _update_view_nodes(node) + _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) return # figure out the right dist op @@ -242,14 +496,24 @@ def check_and_apply(self, gm: GraphModule, node: Node) -> bool: return True +class LayerType(Enum): + ATTENTION = "attention" + MAMBA = "mamba" + MLP = "mlp" + MOE = "moe" + + class TPShardingInfo(ShardingTransformInfo): """Configuration for TP sharding transformations.""" split_dim: SplitDimension dist_op: Optional[Literal["all_reduce", "all_gather"]] = None min_local_shape: int = 1 + layer_type: LayerType = LayerType.MLP # used for TP sharding of fused weights - fused_weights: Optional[list] + # For MLP/Attention: list of dimensions for fused weights (e.g., [dim1, dim2] for QKV) + # For Mamba: dict mapping weight keys to their fused dimensions + fused_weight_dims: Optional[Union[list, Dict[str, 
list]]] = None @classmethod def from_node(cls, node: Node, **kwargs) -> "TPShardingInfo": @@ -279,15 +543,30 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply TP sharding transformation to the graph module.""" - _insert_sharded_matmul( - gm=gm, - node=node, - dim=self.split_dim.value, - rank=self.rank, - world_size=self.world_size, - add_dist=self.dist_op is not None, - min_local_shape=self.min_local_shape, - ) + if self.layer_type == LayerType.MAMBA: + _insert_sharded_mamba( + gm=gm, + entry_node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + fused_weight_dims=self.fused_weight_dims + if isinstance(self.fused_weight_dims, dict) + else None, + ) + else: + _insert_sharded_matmul( + gm=gm, + node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + fused_weight_dims=self.fused_weight_dims, + ) class QuantizationShardingMixin(ABC): @@ -538,28 +817,25 @@ def handle_tensor( end_idx: End index for sharding """ - # Define slice function for the sharding - def slice_tensor(t: torch.Tensor) -> torch.Tensor: - return t[start_idx:end_idx] - if tensor_node.op == "get_attr": - # Handle parameter tensor + # Handle parameter tensor using unified shard_weight_tensor weight_key = tensor_node.target - modname, _, param_name = weight_key.rpartition(".") param = gm.get_parameter(weight_key) - # Update the parameter with its shard - param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True) - gm.get_submodule(modname).register_parameter(param_name, param_new) - - # Register load state dict hook - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, - f_split=slice_tensor, - param_key=weight_key, - param_shape=param_new.shape, - ) + # Define slice function for the sharding + def slice_tensor(t: torch.Tensor) -> torch.Tensor: + return t[start_idx:end_idx] + + # Use shard_weight_tensor with custom shard function (also updates the parameter) + shard_weight_tensor( + gm=gm, + weight_tensor=param, + param_key=weight_key, + dim=0, # BMM slices along batch dimension + rank=self.rank, + world_size=self.world_size, + custom_shard_fn=slice_tensor, + requires_grad=True, # BMM parameters require gradients ) else: # Handle dynamic tensor @@ -911,27 +1187,3 @@ def validate_config(self) -> bool: def get_predefined_config(self) -> Dict[str, Any]: return self.predefined_config - - -def _append_simple_shard( - nodes_linear: Dict[Node, List[Node]], - rank: int, - world_size: int, - sharding_config: ShardingConfig, -) -> None: - # for every linear node: - # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] - for node_group in nodes_linear.values(): - for n in node_group: - tp_shards.append( - TPShardingInfo( - target_node=n.name, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, - ) - ) - sharding_config.tp_transforms.extend(tp_shards) From 609dca82528d689008b46512d8084f524e6eb675 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 21 Oct 2025 00:40:45 -0700 Subject: [PATCH 04/12] debugging Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- 
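Note (editor's sketch, not part of the commit): the DEBUG helpers added in this patch check TP sharding by filling each weight with its own index along the shard dimension and then verifying that every rank ends up with exactly the index set it should own. Below is a minimal standalone illustration of that idea; it assumes a plain (non-fused) column split via torch.tensor_split, uses toy shapes, and its names are illustrative rather than the patch's actual API.

import torch

def initialize_debug_tensor(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Fill the tensor so every slice along `dim` holds its own index value.
    shape = list(t.shape)
    indices = torch.arange(shape[dim], dtype=t.dtype, device=t.device)
    view_shape = [1] * len(shape)
    view_shape[dim] = shape[dim]
    return indices.view(view_shape).expand(shape).clone()

def expected_indices(total: int, rank: int, world_size: int) -> set:
    # Contiguous index chunk this rank should own after a plain column split.
    chunk = total // world_size
    return set(range(rank * chunk, (rank + 1) * chunk))

if __name__ == "__main__":
    world_size, dim = 2, 0
    weight = initialize_debug_tensor(torch.empty(8, 4), dim)  # toy 8x4 weight
    for rank in range(world_size):
        shard = torch.tensor_split(weight, world_size, dim=dim)[rank]
        actual = set(torch.unique(shard).int().tolist())
        assert actual == expected_indices(weight.shape[dim], rank, world_size)
    print("per-rank index check passed")

For fused weights the expected set is instead the union of each fused component's per-rank chunk, which is what the patch's _validate_sharded_indices derives from fused_weight_dims.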
.../auto_deploy/utils/sharding_utils.py | 135 +++++++++++++++++- 1 file changed, 133 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index f81ab0e488c..173ac210dbb 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -29,6 +29,69 @@ modelopt_fp4_scale_to_cutlass_fp4_scale, ) +DEBUG = True + + +def _initialize_debug_tensor(t: torch.Tensor, dim: int) -> torch.Tensor: + """Initialize tensor along dim with sequential indices for debugging.""" + # Create index tensor: t[:, i, :] = i for all i along dimension dim + shape = list(t.shape) + indices = torch.arange(shape[dim], dtype=t.dtype, device=t.device) + # Reshape indices to broadcast correctly + view_shape = [1] * len(shape) + view_shape[dim] = shape[dim] + indices = indices.view(view_shape) + # Broadcast to full shape + return indices.expand(shape).clone() + + +def _validate_sharded_indices( + sharded_tensor: torch.Tensor, + dim: int, + rank: int, + world_size: int, + fused_weight_dims: Optional[list] = None, + param_key: str = "", +): + """Validate that sharded tensor contains expected indices.""" + if not DEBUG: + return + + # Get unique values from the sharded tensor + unique_vals = torch.unique(sharded_tensor).cpu().numpy().astype(int) + + if fused_weight_dims is None: + # Non-fused: expect contiguous chunk + total_size = sharded_tensor.shape[dim] * world_size + chunk_size = total_size // world_size + expected_start = rank * chunk_size + expected_end = expected_start + chunk_size + expected = set(range(expected_start, expected_end)) + else: + # Fused: expect sharded chunks from each fused component + expected = set() + offset = 0 + for fused_dim in fused_weight_dims: + chunk_size = fused_dim // world_size + chunk_start = offset + rank * chunk_size + chunk_end = chunk_start + chunk_size + expected.update(range(chunk_start, chunk_end)) + offset += fused_dim + + actual = set(unique_vals) + + ad_logger.info(f"DEBUG [{param_key}] Rank {rank}: Expected indices: {sorted(expected)}") + ad_logger.info(f"DEBUG [{param_key}] Rank {rank}: Actual indices: {sorted(actual)}") + + assert actual == expected, ( + f"Rank {rank} sharding mismatch for {param_key}!\n" + f"Expected: {sorted(expected)}\n" + f"Actual: {sorted(actual)}\n" + f"Missing: {sorted(expected - actual)}\n" + f"Extra: {sorted(actual - expected)}" + ) + ad_logger.info(f"DEBUG [{param_key}] Rank {rank}: ✓ Validation passed") + def _load_hook( state_dict, @@ -37,6 +100,10 @@ def _load_hook( f_split: Callable[[torch.Tensor, int], torch.Tensor], param_key: str, param_shape: torch.Size, + dim: int, + rank: int, + world_size: int, + fused_weight_dims: Optional[list] = None, ): # TODO: we need to support loading either a sharded or unsharded checkpoint. 
# Otherwise, basic workflows like @@ -48,7 +115,25 @@ def _load_hook( if key not in state_dict: return p_to_load = state_dict[key] + + # Debug: Initialize with sequential indices + if DEBUG and param_shape != p_to_load.shape: + ad_logger.info(f"DEBUG: Initializing tensor '{key}' with sequential indices") + p_to_load = _initialize_debug_tensor(p_to_load, dim) + p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load) + + # Debug: Validate sharded indices + if DEBUG and param_shape != state_dict[key].shape: + _validate_sharded_indices( + p_to_load, + dim=dim, + rank=rank, + world_size=world_size, + fused_weight_dims=fused_weight_dims, + param_key=key, + ) + state_dict[key] = p_to_load @@ -150,6 +235,13 @@ def shard_weight_tensor( Tuple of (sharded_tensor, sharded_shape) """ + # Debug: Initialize tensor with sequential indices + if DEBUG: + weight_tensor = _initialize_debug_tensor(weight_tensor, dim) + ad_logger.info( + f"DEBUG: Initialized weight_tensor for '{param_key}' with sequential indices" + ) + # Use custom shard function if provided if custom_shard_fn is not None: sharded_weight = custom_shard_fn(weight_tensor) @@ -157,7 +249,14 @@ def shard_weight_tensor( # Register load hook with custom function gm._register_load_state_dict_pre_hook( partial( - _load_hook, f_split=custom_shard_fn, param_key=param_key, param_shape=sharded_shape + _load_hook, + f_split=custom_shard_fn, + param_key=param_key, + param_shape=sharded_shape, + dim=dim, + rank=rank, + world_size=world_size, + fused_weight_dims=fused_weight_dims, ) ) else: @@ -187,15 +286,47 @@ def split_tensor( [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], dim=dim, ) + + # Create a function that applies the same logic for loading + def split_fused_tensor( + t: torch.Tensor, + fused_dims: list = fused_weight_dims, + d: int = dim, + ) -> torch.Tensor: + return torch.cat( + [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)], + dim=d, + ) + + f_split = split_fused_tensor else: sharded_weight = split_tensor(weight_tensor) + f_split = split_tensor sharded_shape = sharded_weight.shape + # Debug: Validate sharded indices + if DEBUG: + _validate_sharded_indices( + sharded_weight, + dim=dim, + rank=rank, + world_size=world_size, + fused_weight_dims=fused_weight_dims, + param_key=param_key, + ) + # Register load hook gm._register_load_state_dict_pre_hook( partial( - _load_hook, f_split=split_tensor, param_key=param_key, param_shape=sharded_shape + _load_hook, + f_split=f_split, + param_key=param_key, + param_shape=sharded_shape, + dim=dim, + rank=rank, + world_size=world_size, + fused_weight_dims=fused_weight_dims, ) ) From ad7364b823b12378a414ea49eb26fc83e9aa5ba9 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 21 Oct 2025 01:14:54 -0700 Subject: [PATCH 05/12] wip debugging Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index da9584c1f08..11dc80ac28d 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -239,6 +239,12 @@ def filtered_nodes( for node in nodes: if target(node): yield node + elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target): + for node in nodes: 
+ for t in target: + if t(node): + yield node + break else: # Handle the case where target or ops contains operations operations = ops if ops is not None else target From 3c93381c64945a790adef5b69a329d103db62d8d Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 21 Oct 2025 12:09:17 -0700 Subject: [PATCH 06/12] working nemotron sharding Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/models/patches/nemotron_h.py | 4 +- .../auto_deploy/transform/library/sharding.py | 43 ++-- .../auto_deploy/utils/sharding_utils.py | 210 ++++++------------ 3 files changed, 91 insertions(+), 166 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index b46d48848f4..aefb66a5e65 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -150,11 +150,11 @@ def get_model_from_config_patched(config, **kwargs): def _set_sharding_config_patched(self, *args, **kwargs): self._sharding_config["head_dim"] = 128 self._sharding_config["tp_plan"] = { - "in_proj": 'mamba("fused_weight_dims" = {"in_proj": [8192,8192,1024,1024,128], "conv1d": [8192, 1024, 1024]})', + "in_proj": "mamba", "out_proj": "rowwise", "up_proj": "colwise", "down_proj": "rowwise", - # "*": "gather", + "*": "gather", } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index f640f09e159..6ab1d024f40 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -167,6 +167,7 @@ def _apply( shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: local_rank, world_size = shared_config.local_rank, shared_config.world_size + # world_size = 2 if world_size < 2: ad_logger.info("Skipping sharding for single device") @@ -175,42 +176,35 @@ def _apply( ) assert isinstance(gm, GraphModule), "Expecting GraphModule" - shared_config.sharding_config.rank = local_rank - shared_config.sharding_config.world_size = world_size - shared_config.sharding_config.predefined_config = ( - factory.get_sharding_config() if factory else {} - ) - shared_config.sharding_config.factory_source = ( - shared_config.sharding_config.predefined_config.get( - "source", ShardingConfigSource.UNKNOWN - ) + sharding_config = shared_config.sharding_config + sharding_config.rank = local_rank + sharding_config.world_size = world_size + sharding_config.predefined_config = factory.get_sharding_config() if factory else {} + sharding_config.factory_source = ( + sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN) if factory else ShardingConfigSource.UNKNOWN ) - shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only - shared_config.sharding_config.support_partial_config = self.config.support_partial_config - shared_config.sharding_config.sharding_dims = self.config.sharding_dims + sharding_config.simple_shard_only = self.config.simple_shard_only + sharding_config.support_partial_config = self.config.support_partial_config + sharding_config.sharding_dims = self.config.sharding_dims - shared_config.sharding_config.use_sharding_from_factory = ( - self.config.use_sharding_from_factory - ) + sharding_config.use_sharding_from_factory = self.config.use_sharding_from_factory - 
sharding_config = shared_config.sharding_config sharding_config.validate_config() + # sharding_config.predefined_config = predefined_config if ( - shared_config.sharding_config.use_sharding_from_factory - and len(shared_config.sharding_config.get_predefined_config()) > 0 + sharding_config.use_sharding_from_factory + and len(sharding_config.get_predefined_config()) > 0 ): ad_logger.info("Applying sharding from config") factory_info = detect_sharding_from_factory_config(gm, sharding_config) return gm, factory_info - ad_logger.info( - f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" - ) + ad_logger.info(f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}") # run TP sharding across ranks - if "tp" in shared_config.sharding_config.sharding_dims: + if "tp" in sharding_config.sharding_dims: tp_info = detect_column_row_shard(gm, sharding_config) else: tp_info = TransformInfo( @@ -218,7 +212,7 @@ def _apply( ) # run EP sharding across ranks - if "ep" in shared_config.sharding_config.sharding_dims: + if "ep" in sharding_config.sharding_dims: ep_info = detect_ep_shard(gm, sharding_config) else: ep_info = TransformInfo( @@ -226,7 +220,7 @@ def _apply( ) # run BMM sharding across ranks - if "bmm" in shared_config.sharding_config.sharding_dims: + if "bmm" in sharding_config.sharding_dims: dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) else: dp_bmm_info = TransformInfo( @@ -345,6 +339,7 @@ def detect_sharding_from_factory_config( min_local_shape=min_local_shape, ) ) + num_row_col_shards += 1 elif config == "rowwise": sharding_config.tp_transforms.append( TPShardingInfo.from_node( diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 173ac210dbb..03a225ce51a 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -22,76 +22,12 @@ is_op, num_users_of_weight_node, subgraph, - successors, ) from .quantization_utils import ( cutlass_fp4_scale_to_modelopt_fp4_scale, modelopt_fp4_scale_to_cutlass_fp4_scale, ) -DEBUG = True - - -def _initialize_debug_tensor(t: torch.Tensor, dim: int) -> torch.Tensor: - """Initialize tensor along dim with sequential indices for debugging.""" - # Create index tensor: t[:, i, :] = i for all i along dimension dim - shape = list(t.shape) - indices = torch.arange(shape[dim], dtype=t.dtype, device=t.device) - # Reshape indices to broadcast correctly - view_shape = [1] * len(shape) - view_shape[dim] = shape[dim] - indices = indices.view(view_shape) - # Broadcast to full shape - return indices.expand(shape).clone() - - -def _validate_sharded_indices( - sharded_tensor: torch.Tensor, - dim: int, - rank: int, - world_size: int, - fused_weight_dims: Optional[list] = None, - param_key: str = "", -): - """Validate that sharded tensor contains expected indices.""" - if not DEBUG: - return - - # Get unique values from the sharded tensor - unique_vals = torch.unique(sharded_tensor).cpu().numpy().astype(int) - - if fused_weight_dims is None: - # Non-fused: expect contiguous chunk - total_size = sharded_tensor.shape[dim] * world_size - chunk_size = total_size // world_size - expected_start = rank * chunk_size - expected_end = expected_start + chunk_size - expected = set(range(expected_start, expected_end)) - else: - # Fused: expect sharded chunks from each fused component - expected = set() - offset = 0 - for fused_dim in fused_weight_dims: - chunk_size = fused_dim // 
world_size - chunk_start = offset + rank * chunk_size - chunk_end = chunk_start + chunk_size - expected.update(range(chunk_start, chunk_end)) - offset += fused_dim - - actual = set(unique_vals) - - ad_logger.info(f"DEBUG [{param_key}] Rank {rank}: Expected indices: {sorted(expected)}") - ad_logger.info(f"DEBUG [{param_key}] Rank {rank}: Actual indices: {sorted(actual)}") - - assert actual == expected, ( - f"Rank {rank} sharding mismatch for {param_key}!\n" - f"Expected: {sorted(expected)}\n" - f"Actual: {sorted(actual)}\n" - f"Missing: {sorted(expected - actual)}\n" - f"Extra: {sorted(actual - expected)}" - ) - ad_logger.info(f"DEBUG [{param_key}] Rank {rank}: ✓ Validation passed") - def _load_hook( state_dict, @@ -100,10 +36,6 @@ def _load_hook( f_split: Callable[[torch.Tensor, int], torch.Tensor], param_key: str, param_shape: torch.Size, - dim: int, - rank: int, - world_size: int, - fused_weight_dims: Optional[list] = None, ): # TODO: we need to support loading either a sharded or unsharded checkpoint. # Otherwise, basic workflows like @@ -116,24 +48,8 @@ def _load_hook( return p_to_load = state_dict[key] - # Debug: Initialize with sequential indices - if DEBUG and param_shape != p_to_load.shape: - ad_logger.info(f"DEBUG: Initializing tensor '{key}' with sequential indices") - p_to_load = _initialize_debug_tensor(p_to_load, dim) - p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load) - # Debug: Validate sharded indices - if DEBUG and param_shape != state_dict[key].shape: - _validate_sharded_indices( - p_to_load, - dim=dim, - rank=rank, - world_size=world_size, - fused_weight_dims=fused_weight_dims, - param_key=key, - ) - state_dict[key] = p_to_load @@ -165,22 +81,21 @@ def _validate_sharded_shapes( # get the subgraph of this module. Subgraph boundary is the next linear node. next_lin_node, depth = bfs(node, is_linear_op, include_root=False) - # split nodes can't have "-1" for split size. 
- nodes_to_validate = successors( - node, - depth=depth, - exclude=lambda n: is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]), + nodes_to_validate = subgraph( + [node], + [next_lin_node], + include=lambda n: is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]), ) for view_node in nodes_to_validate: - # shard weight tensors for RMS norm and conv1d if len(view_node.args) < 2: continue - view_shape = view_node.args[1] + view_shape = list(view_node.args[1]) if not isinstance(view_shape, list): continue if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: args = list(view_node.args) - args[1] = [view_shape[0], view_shape[1], -1] + view_shape[3:] + view_shape[2] = view_shape[2] // world_size + args[1] = tuple(view_shape) view_node.args = tuple(args) ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") @@ -189,9 +104,9 @@ def _validate_sharded_shapes( assert world_size is not None, "World size is required to update the split node params" assert len(node.users) == 1, "Fused linear node should have only one user: a split node" # find all split nodes in the region between this linear node and the next - split_nodes = successors( - node, - depth=depth, + split_nodes = subgraph( + [node], + [next_lin_node], include=lambda n: is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]), ) for split_node in split_nodes: @@ -235,13 +150,6 @@ def shard_weight_tensor( Tuple of (sharded_tensor, sharded_shape) """ - # Debug: Initialize tensor with sequential indices - if DEBUG: - weight_tensor = _initialize_debug_tensor(weight_tensor, dim) - ad_logger.info( - f"DEBUG: Initialized weight_tensor for '{param_key}' with sequential indices" - ) - # Use custom shard function if provided if custom_shard_fn is not None: sharded_weight = custom_shard_fn(weight_tensor) @@ -253,12 +161,9 @@ def shard_weight_tensor( f_split=custom_shard_fn, param_key=param_key, param_shape=sharded_shape, - dim=dim, - rank=rank, - world_size=world_size, - fused_weight_dims=fused_weight_dims, ) ) + else: def split_tensor( @@ -305,17 +210,6 @@ def split_fused_tensor( sharded_shape = sharded_weight.shape - # Debug: Validate sharded indices - if DEBUG: - _validate_sharded_indices( - sharded_weight, - dim=dim, - rank=rank, - world_size=world_size, - fused_weight_dims=fused_weight_dims, - param_key=param_key, - ) - # Register load hook gm._register_load_state_dict_pre_hook( partial( @@ -323,10 +217,6 @@ def split_fused_tensor( f_split=f_split, param_key=param_key, param_shape=sharded_shape, - dim=dim, - rank=rank, - world_size=world_size, - fused_weight_dims=fused_weight_dims, ) ) @@ -365,7 +255,7 @@ def _insert_sharded_mamba( quantization_cb: Optional[ Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] ] = None, -) -> None: +) -> bool: """ To shard Mamba layer, first column-shard the first linear layer: entry_node, then shard all remaining weight tensors found in the subgraph defined between @@ -391,24 +281,63 @@ def _insert_sharded_mamba( next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False) except RuntimeError: ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") - return + return False # Get subgraph between entry_node and next linear node subgraph_nodes = subgraph([entry_node], [next_lin_node]) - # Validate this is a Mamba module by checking for torch_ssm_transform and conv1d - has_ssm = any(is_op(n, torch.ops.auto_deploy.torch_ssm_transform) for n in subgraph_nodes) - 
has_conv1d = any( - is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) + ############################################################## + ########## validate if this is a valid Mamba module ########## + ############################################################## + # has_ssm = any(is_op(n, torch.ops.auto_deploy.mamba.torch_ssm_transform) for n in subgraph_nodes) + has_ssm = True + conv1d_nodes = [ + n for n in subgraph_nodes - ) - - if not (has_ssm and has_conv1d): + if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) + ] + if len(conv1d_nodes) != 1 or not has_ssm: ad_logger.warning( - f"Subgraph does not contain both torch_ssm_transform and conv1d nodes. " - f"Skipping Mamba sharding. has_ssm={has_ssm}, has_conv1d={has_conv1d}" + f"Subgraph does not contain exactly one conv1d node and torch_ssm_transform. " + f"Skipping Mamba sharding. conv1d_nodes={conv1d_nodes}, has_ssm={has_ssm}" ) - return + return False + + ############################################################## + ########## infer split sizes for in_proj and conv1d ########## + ############################################################## + # in_proj and conv1d are most likely fused, followed up by split nodes. Infer split sizes: + if fused_weight_dims is None: + split_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) + ] + if len(split_nodes) != 2: + ad_logger.warning( + f"Subgraph does not contain exactly two split nodes. " + f"Skipping Mamba sharding. split_nodes={split_nodes}" + ) + return False + split_sizes_1 = split_nodes[0].args[1] + split_sizes_2 = split_nodes[1].args[1] + if split_sizes_1[1] != sum(split_sizes_2): + ad_logger.warning( + f"Split nodes have different sizes. " + f"Skipping Mamba sharding. split_sizes_1={split_sizes_1}, split_sizes_2={split_sizes_2}" + ) + return False + fused_weight_dims = { + "in_proj": split_sizes_1[0:1] + split_sizes_2 + split_sizes_1[2:], + "conv1d": split_sizes_2, + } + + conv1d_node = conv1d_nodes[0] + # conv1d_node last argument is the number of output channels. 
+ # This one is also sharded, so we need to update this parameter + conv_args = list(conv1d_node.args) + conv_args[-1] = conv1d_node.args[-1] // world_size + conv1d_node.args = tuple(conv_args) # First, shard the entry_node (the first linear layer) # Extract entry node's fused_weight_dims by matching weight name against patterns @@ -432,20 +361,20 @@ def _insert_sharded_mamba( quantization_cb=quantization_cb, ) - # Get all weight nodes in the subgraph + # Get all weight nodes in the subgraph except for out_proj weight_nodes = [ n for n in get_all_weights_in_subgraph([entry_node], [next_lin_node]) if "out_proj" not in str(n) ] - # Shard remaining weights + # Shard remaining weights, such as conv1d or RMSNorm for weight_node in weight_nodes: weight_key = weight_node.target # Filter by regex patterns if provided if weights_to_shard is not None: - if not any(re.search(pattern, weight_key) for pattern in weights_to_shard): + if not any(pattern in weight_key for pattern in weights_to_shard): continue # Determine shard dimension for this weight @@ -459,7 +388,11 @@ def _insert_sharded_mamba( continue # Get fused dims for this weight if specified - fused_dims = fused_weight_dims.get(weight_key) if fused_weight_dims else None + fused_dims = None + for k, v in fused_weight_dims.items(): + if k in weight_key: + fused_dims = v + break # Shard the weight tensor (also updates the parameter in the module) _, sharded_shape = shard_weight_tensor( @@ -757,9 +690,6 @@ def quantization_cb( self.shard_load_hook, weight_name=weight_key, weight_shape=weight_new_shape, - dim=dim, - rank=rank, - world_size=world_size, ) ) From bffef4b788d02e09231bf4da90669f8b8739b602 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 21 Oct 2025 12:31:55 -0700 Subject: [PATCH 07/12] removed redundant files Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- modelling_nemotron_h.py | 1899 --------------------------------------- simple_nemotron.py | 731 --------------- 2 files changed, 2630 deletions(-) delete mode 100644 modelling_nemotron_h.py delete mode 100644 simple_nemotron.py diff --git a/modelling_nemotron_h.py b/modelling_nemotron_h.py deleted file mode 100644 index 2c51f60b93f..00000000000 --- a/modelling_nemotron_h.py +++ /dev/null @@ -1,1899 +0,0 @@ -# coding=utf-8 -# Copyright 2024 HuggingFace Inc. team. -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch NemotronH model.""" - -import math -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.activations import ACT2FN -from transformers.cache_utils import \ - DynamicCache # we need __iter__ and __len__ of pkv -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import (ModelOutput, add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, logging) -from transformers.utils.import_utils import ( - is_causal_conv1d_available, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_mamba_2_ssm_available) - -from .configuration_nemotron_h import NemotronHConfig - -logger = logging.get_logger(__name__) - -# Copied from transformers.models.mamba.modeling_mamba2.modeling_mamba2.py with MAMBA2->NEMOTRONH,Mamba2->NemotronH -# For Mamba2 components Mamba2->NemotronHMamba2 -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import \ - selective_state_update - from mamba_ssm.ops.triton.ssd_combined import ( - mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined) -else: - mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None - -try: - #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated - from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn -except ImportError: - raise ImportError( - "mamba-ssm is required by the Mamba model but cannot be imported") - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - -if is_flash_attn_2_available(): - from transformers.modeling_flash_attention_utils import \ - _flash_attention_forward - -is_fast_path_available = all(( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, -)) - -_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K" -_CONFIG_FOR_DOC = "NemotronHConfig" - -# Helper methods for segment sum computation - - -def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): - """ - Padding x tensor with `pad_size` on the seq_len dim (dim=1) - - Assumes that we only have tensors of either size 4 or 3 - """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, - 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, - 0) - - return torch.nn.functional.pad(input_tensor, - pad_shape, - mode="constant", - value=0) - - -def reshape_into_chunks(input_tensor, pad_size, chunk_size): - """ - Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and - simultaneously splitting it into chunk sequences. - - Assumes that we only have tensors of either size 4 or 3 - """ - # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
- input_tensor = pad_tensor_by_size(input_tensor, pad_size) - - if len(input_tensor.shape) == 3: - # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] - return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, - input_tensor.shape[2]) - else: - # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] - return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, - input_tensor.shape[2], - input_tensor.shape[3]) - - -def segment_sum(input_tensor): - """ - More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. - """ - chunk_size = input_tensor.size(-1) - # 1. expand input tensor to have an additional dimension and repeat along that dimension - # [..., chunk_size] -> [..., chunk_size, chunk_size] - input_tensor = input_tensor[..., None].expand(*input_tensor.size(), - chunk_size) - # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag - mask = torch.tril(torch.ones(chunk_size, - chunk_size, - device=input_tensor.device, - dtype=torch.bool), - diagonal=-1) - input_tensor = input_tensor.masked_fill(~mask, 0) - # 3. compute actual cumsum - tensor_segsum = torch.cumsum(input_tensor, dim=-2) - - # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) - mask = torch.tril(torch.ones(chunk_size, - chunk_size, - device=input_tensor.device, - dtype=torch.bool), - diagonal=0) - tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) - return tensor_segsum - - -def apply_mask_to_padding_states(hidden_states, attention_mask): - """ - Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 - """ - if attention_mask is not None and attention_mask.shape[ - 1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - - return hidden_states - - -# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py -class HybridMambaAttentionDynamicCache(DynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - super().__init__() - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_override_pattern - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_num_heads * config.mamba_head_dim - ssm_state_size = config.ssm_state_size - conv_kernel_size = config.conv_kernel - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "M": - # Mamba layer - self.conv_states += [ - torch.zeros(batch_size, - intermediate_size, - conv_kernel_size, - device=device, - dtype=dtype) - ] - self.ssm_states += [ - torch.zeros(batch_size, - intermediate_size, - ssm_state_size, - device=device, - dtype=dtype) - ] - else: - # Attention or MLP layer - self.conv_states += [ - torch.tensor([[]] * batch_size, device=device) - ] - self.ssm_states += [ - torch.tensor([[]] * batch_size, device=device) - ] - self.transformer_layers.append(i) - - self.key_cache = [ - torch.tensor([[]] * batch_size, device=device) - for _ in range(config.num_hidden_layers) - ] - self.value_cache = [ - torch.tensor([[]] * batch_size, device=device) - for _ in range(config.num_hidden_layers) - ] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( - 0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[ - layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[ - layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[ - layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[ - 0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def to_legacy_cache( - self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError( - "HybridMambaAttentionDynamicCache does not have a legacy cache equivalent." - ) - - @classmethod - def from_legacy_cache( - cls, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "DynamicCache": - raise NotImplementedError( - "HybridMambaAttentionDynamicCache does not have a legacy cache equivalent." 
- ) - - # Copied from modeling_mamba2.py - def update_conv_state(self, - layer_idx: int, - new_conv_state: torch.Tensor, - cache_init: bool = False) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to( - self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll( - shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to( - self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -class MambaRMSNormGated(torch.nn.Module): - - def __init__(self, hidden_size, group_size, eps=1e-5): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - self.group_size = group_size - - # jan28b version - def forward(self, hidden_states, gate=None): - return rmsnorm_fn( - x=hidden_states, - weight=self.weight, - bias=None, # No bias - z=gate, - eps=self.variance_epsilon, - group_size=self.group_size, - norm_before_gate=False) - - -class NemotronHMamba2Mixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. - A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) - ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, - and is why Mamba is called **selective** state spaces) - """ - - def __init__(self, config: NemotronHConfig, layer_idx: int): - super().__init__() - self.num_heads = config.mamba_num_heads - self.hidden_size = config.hidden_size - self.ssm_state_size = config.ssm_state_size - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim - self.layer_idx = layer_idx - self.use_conv_bias = config.use_conv_bias - self.activation = config.mamba_hidden_act - self.act = ACT2FN[config.mamba_hidden_act] - - self.layer_norm_epsilon = config.layer_norm_epsilon - - self.n_groups = config.n_groups - self.head_dim = config.mamba_head_dim - self.chunk_size = config.chunk_size - - self.time_step_limit = config.time_step_limit - self.time_step_min = config.time_step_min - self.time_step_max = config.time_step_max - - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size - self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - bias=config.use_conv_bias, - kernel_size=config.conv_kernel, - groups=self.conv_dim, - padding=config.conv_kernel - 1, - ) - - # projection of the input hidden states - projection_size = self.intermediate_size + self.conv_dim + self.num_heads - self.in_proj = nn.Linear( - self.hidden_size, - projection_size, - bias=config.use_bias, - ) - # selective projection used to make dt, B and C input dependent - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) - - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded - A = torch.arange(1, self.num_heads + 1) - self.A_log = nn.Parameter(torch.log(A)) - self.A_log._no_weight_decay = True - self.norm = MambaRMSNormGated(self.intermediate_size, - eps=self.layer_norm_epsilon, - group_size=self.intermediate_size // - self.n_groups) - self.D = nn.Parameter(torch.ones(self.num_heads)) - self.D._no_weight_decay = True - - self.out_proj = nn.Linear(self.intermediate_size, - self.hidden_size, - bias=config.use_bias) - self.use_bias = config.use_bias - - if not is_fast_path_available: - logger.warning_once( - "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" - " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d") - - def cuda_kernels_forward( - self, - hidden_states: torch.Tensor, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - # 1. Gated MLP's linear projection - hidden_states = apply_mask_to_padding_states(hidden_states, - attention_mask) - projected_states = self.in_proj(hidden_states) - - # Set up dimensions for reshapes later - batch_size, seq_len, _ = hidden_states.shape - groups_time_state_size = self.n_groups * self.ssm_state_size - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2 - - # Single step calculations via cache - if cache_params is not None and cache_position is not None and cache_position[ - 0] > 0: - _, _, gate, hidden_states_B_C, dt = projected_states.squeeze( - 1).split([ - d_mlp, d_mlp, self.intermediate_size, self.conv_dim, - self.num_heads - ], - dim=-1) - - # 2. Convolution sequence transformation - hidden_states_B_C = causal_conv1d_update( - hidden_states_B_C, - cache_params.conv_states[self.layer_idx], - self.conv1d.weight.squeeze(1), - self.conv1d.bias, - self.activation, - ) - - hidden_states, B, C = torch.split( - hidden_states_B_C, - [ - self.intermediate_size, groups_time_state_size, - groups_time_state_size - ], - dim=-1, - ) - - # 3. SSM transformation - A = -torch.exp(self.A_log.float()) # (nheads,) - A = A[:, None, - ...][:, :, - None].expand(-1, self.head_dim, - self.ssm_state_size).to(dtype=torch.float32) - dt = dt[:, :, None].expand(-1, -1, self.head_dim) - dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) - D = self.D[:, None, ...].expand(-1, self.head_dim) - B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) - C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) - hidden_states_reshaped = hidden_states.view(batch_size, - self.num_heads, - self.head_dim) - hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], - hidden_states_reshaped, - dt, - A, - B, - C, - D, - z=None, - dt_bias=dt_bias, - dt_softplus=True, - ) - hidden_states = hidden_states.view(batch_size, - self.num_heads * self.head_dim) - hidden_states = self.norm(hidden_states, gate) - - # 4. Final linear projection - out = self.out_proj(hidden_states)[:, None, ...] - - # Fused calculations or step by step if no initialized cache is found - else: - A = -torch.exp(self.A_log.float() - ) # (num_heads) or (intermediate_size, state_size) - dt_limit_kwargs = {} if self.time_step_limit == ( - 0.0, float("inf")) else { - "dt_limit": self.time_step_limit - } - - # 2-4. 
Fused kernel for conv1d, SSM, and the final projection - if self.training and cache_params is None: - out = mamba_split_conv1d_scan_combined( - projected_states, - self.conv1d.weight.squeeze(1), - self.conv1d.bias, - self.dt_bias, - A, - D=self.D, - chunk_size=self.chunk_size, - seq_idx=None, # was seq_idx - activation=self.activation, - rmsnorm_weight=self.norm.weight, - rmsnorm_eps=self.norm.variance_epsilon, - outproj_weight=self.out_proj.weight, - outproj_bias=self.out_proj.bias, - headdim=self.head_dim, - ngroups=self.n_groups, - norm_before_gate=False, - return_final_states=False, - **dt_limit_kwargs, - ) - - else: - _, _, gate, hidden_states_B_C, dt = projected_states.split( - [ - d_mlp, d_mlp, self.intermediate_size, self.conv_dim, - self.num_heads - ], - dim=-1) - - # 2. Convolution sequence transformation - # Init cache - if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose( - 1, 2) - conv_states = nn.functional.pad( - hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - - hidden_states_B_C_transposed.shape[-1], 0), - ) - cache_params.update_conv_state(layer_idx=self.layer_idx, - new_conv_state=conv_states, - cache_init=True) - - if self.activation not in ["silu", "swish"]: - hidden_states_B_C = self.act( - self.conv1d(hidden_states_B_C.transpose( - 1, 2))[..., :seq_len].transpose(1, 2)) - else: - hidden_states_B_C = causal_conv1d_fn( - x=hidden_states_B_C.transpose(1, 2), - weight=self.conv1d.weight.squeeze(1), - bias=self.conv1d.bias, - activation=self.activation, - ).transpose(1, 2) - hidden_states_B_C = apply_mask_to_padding_states( - hidden_states_B_C, attention_mask) - hidden_states, B, C = torch.split( - hidden_states_B_C, - [ - self.intermediate_size, groups_time_state_size, - groups_time_state_size - ], - dim=-1, - ) - - # 3. SSM transformation - scan_output, ssm_state = mamba_chunk_scan_combined( - hidden_states.view(batch_size, seq_len, -1, self.head_dim), - dt, - A, - B.view(batch_size, seq_len, self.n_groups, -1), - C.view(batch_size, seq_len, self.n_groups, -1), - chunk_size=self.chunk_size, - D=self.D, - z=None, - seq_idx=None, - return_final_states=True, - dt_bias=self.dt_bias, - dt_softplus=True, - **dt_limit_kwargs, - ) - - # Init cache - if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, - new_ssm_state=ssm_state) - - scan_output = scan_output.view(batch_size, seq_len, -1) - - # Multiply "gate" branch and apply extra normalization layer - scan_output = self.norm(scan_output, gate) - - # 4. Final linear projection - out = self.out_proj(scan_output) - return out - - # fmt: off - def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): - batch_size, seq_len, _ = input_states.shape - dtype = input_states.dtype - - # 1. Gated MLP's linear projection - input_states = apply_mask_to_padding_states(input_states, attention_mask) - projected_states = self.in_proj(input_states) - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 - _, _, gate, hidden_states_B_C, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 - ) - - # 2. 
Convolution sequence transformation - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) - - # We need to guarantee that anything regarding the cache is on the same device - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) - - hidden_states_B_C = torch.sum( - conv_states * self.conv1d.weight.squeeze(1), dim=-1 - ) - if self.use_conv_bias: - hidden_states_B_C = hidden_states_B_C + self.conv1d.bias - hidden_states_B_C = self.act(hidden_states_B_C) - else: - # Init cache - if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - conv_states = nn.functional.pad( - hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) - ) - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) - - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) - - hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) - hidden_states, B, C = torch.split( - hidden_states_B_C, - [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], - dim=-1 - ) - - # 3. SSM transformation - A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - # We need to guarantee that anything regarding the cache is on the same device - cache_device = cache_params.ssm_states.device - - # Note: there is no need to pad parameter matrices here, as there is just one new token - # for batched generation - dt = dt[:, 0, :][:, None, ...] 
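A small self-check with toy shapes (not part of the patch) for the cached convolution branch above: at decode time the depthwise causal conv reduces to a dot product between the rolled conv-state buffer and the conv weights, which is exactly what the `torch.sum(conv_states * self.conv1d.weight.squeeze(1), dim=-1)` line computes before the activation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, channels, kernel, seq = 2, 8, 4, 16
weight = torch.randn(channels, kernel)            # conv1d.weight.squeeze(1)
bias = torch.randn(channels)
x = torch.randn(batch, channels, seq)

# Full causal depthwise conv over the sequence (pre-activation).
full = F.conv1d(x, weight[:, None, :], bias, padding=kernel - 1,
                groups=channels)[..., :seq]

# Cached single-token path: the state buffer holds the last `kernel` inputs.
state = x[..., -kernel:]
step = (state * weight).sum(-1) + bias
print(torch.allclose(step, full[..., -1], atol=1e-5))   # True
```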
- dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) - # [num_heads] -> [num_heads, head_dim] - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) - - dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) - dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) - A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) - # [bsz, num_heads, head_dim, state_size] - dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) - - # Discretize B - # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> - # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] - B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] - B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() - B = B.reshape(batch_size, -1, B.shape[-1]) - # [bsz, num_heads, head_dim, state_size] - dB = dt[..., None] * B[..., None, :] - - # Discretize x into dB - # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] - hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) - dBx = (dB * hidden_states[..., None]).to(device=cache_device) - - # State calculation - cache_params.update_ssm_state( - layer_idx=self.layer_idx, - new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx - ) - - # Subsequent output - # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] - C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] - C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() - C = C.reshape(batch_size, -1, C.shape[-1]) - # [bsz, num_heads, head_dim] - - ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] - # Reshape ssm_states to merge the first two dimensions - ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] - C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] - y = torch.bmm(ssm_states_reshaped, C_reshaped) - y = y.view(batch_size, self.num_heads, self.head_dim) - - # D skip connection - # [num_heads] -> [num_heads, head_dim] - D = self.D[..., None].expand(self.D.shape[0], self.head_dim) - y = (y + hidden_states * D).to(y.dtype) - - # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] - y = y.reshape(batch_size, -1)[:, None, ...] 
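To make the cached branch above easier to follow, here is the same single-token recurrence written with toy shapes (b = batch, h = heads, d = head_dim, n = state size); B and C are assumed to be already broadcast from `n_groups` to `num_heads`, and the names are illustrative.

```python
import torch

b, h, d, n = 2, 4, 8, 16
state = torch.zeros(b, h, d, n)          # cache_params.ssm_states[layer]
x = torch.randn(b, h, d)                 # hidden_states for the new token
dt = torch.rand(b, h)                    # after softplus + clamp
A = -torch.rand(h)                       # A = -exp(A_log)
B = torch.randn(b, h, n)
C = torch.randn(b, h, n)
D = torch.randn(h)

dA = torch.exp(dt[..., None, None] * A[None, :, None, None])      # state decay
dBx = dt[..., None, None] * B[:, :, None, :] * x[..., None]       # discretized input
state = state * dA + dBx                                          # state update
y = (state * C[:, :, None, :]).sum(-1) + x * D[None, :, None]     # output + D skip
print(y.shape)                                                    # [b, h, d]
```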
- else: - # begin ssd naive implementation without einsums - dt = nn.functional.softplus(dt + self.dt_bias) - dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) - hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() - B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) - pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size - - D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) - - # Discretize x and A - hidden_states = hidden_states * dt[..., None] - A = A.to(hidden_states.dtype) * dt - - # Rearrange into blocks/chunks - hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] - - # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] - A = A.permute(0, 3, 1, 2) - A_cumsum = torch.cumsum(A, dim=-1) - - # 1. Compute the output for each intra-chunk (diagonal blocks) - # This is the analog of a causal mask - L = torch.exp(segment_sum(A)) - - # Contraction of C and B to get G (attention-weights like) - G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) - G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) - - # Compute M, equivalent to applying attention mask to weights - M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] - M = M_intermediate.sum(dim=-1) - - # Compute Y_diag (apply to values) - Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) - - # 2. Compute the state for each intra-chunk - # (right term of low-rank factorization of off-diagonal blocks; B terms) - decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) - B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] - states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) - - # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries - # (middle term of factorization of off-diag blocks; A terms) - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) - else: - previous_states = torch.zeros_like(states[:, :1]) - states = torch.cat([previous_states, states], dim=1) - decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) - decay_chunk = decay_chunk.transpose(1, 3) - new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) - states, ssm_state = new_states[:, :-1], new_states[:, -1] - - # 4. 
Compute state -> output conversion per chunk - # (left term of low-rank factorization of off-diagonal blocks; C terms) - state_decay_out = torch.exp(A_cumsum) - C_times_states = (C[..., None, :] * states[:, :, None, ...]) - state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) - Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) - y = Y_diag + Y_off - # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] - y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) - - y = y + D_residual - # Cutting off padded chunks - if pad_size > 0: - y = y[:, :seq_len, :, :] - y = y.reshape(batch_size, seq_len, -1) - - # Init cache - if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) - - scan_output = self.norm(y, gate) - - # end ssd naive - - # 4. Final linear projection - contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] - return contextualized_states - # fmt: on - - def forward( - self, - hidden_states, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: - return self.cuda_kernels_forward(hidden_states, cache_params, - cache_position, attention_mask) - dtype = hidden_states.dtype - if attention_mask is not None and attention_mask.shape[ - 1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * - attention_mask[:, :, None]).to(dtype) - - return self.torch_forward(hidden_states, cache_params, cache_position, - attention_mask) - - -class NemotronHRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - # Weights are in float32 - return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) - - -class NemotronHBlock(nn.Module): - - def __init__(self, config, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.residual_in_fp32 = config.residual_in_fp32 - self.norm = NemotronHRMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - - # M: Mamba2, *: Attention, -: MLP - self.block_type = config.layers_block_type[layer_idx] - if self.block_type == "mamba": - self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx) - elif self.block_type == "attention": - self.mixer = NEMOTRONH_ATTENTION_CLASSES[ - config._attn_implementation](config, layer_idx=layer_idx) - elif self.block_type == "mlp": - self.mixer = NemotronHMLP(config, layer_idx=layer_idx) - elif self.block_type == "moe": - self.mixer = NemotronHMOE(config, layer_idx=layer_idx) - else: - raise ValueError( - f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}" - ) - - def forward( - self, - hidden_states, - cache_params: 
Optional[HybridMambaAttentionDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): - # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs - residual = hidden_states - hidden_states = self.norm( - hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - if self.block_type == "mamba": - hidden_states = self.mixer(hidden_states, - cache_params=cache_params, - cache_position=cache_position) - elif self.block_type == "attention": - hidden_states = self.mixer(hidden_states, - cache_position=cache_position) - hidden_states = hidden_states[0] - elif self.block_type in ["mlp", "moe"]: - hidden_states = self.mixer(hidden_states) - else: - raise ValueError(f"Invalid block_type: {self.block_type}") - - hidden_states = residual + hidden_states - return hidden_states - - -# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH -class NemotronHMLP(nn.Module): - - def __init__(self, - config, - intermediate_size=None, - layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") - self.hidden_size = config.hidden_size - self.intermediate_size = intermediate_size or config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, - self.intermediate_size, - bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, - self.hidden_size, - bias=config.mlp_bias) - self.act_fn = ACT2FN[config.mlp_hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.up_proj(x))) - - -class NemotronHMOE(nn.Module): - - def __init__(self, config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.experts = nn.ModuleList([ - NemotronHMLP(config, - intermediate_size=config.moe_intermediate_size, - layer_idx=layer_idx) - for _ in range(config.n_routed_experts) - ]) - self.gate = NemotronHTopkRouter(config) - self.shared_experts = NemotronHMLP( - config=config, - intermediate_size=config.moe_shared_expert_intermediate_size, - layer_idx=layer_idx) - - def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, - topk_weights: torch.Tensor): - r""" - CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused - to not have to do a loop here (deepseek has 256 experts soooo yeah). 
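For orientation, a dense toy version of the routed-expert combine that the loop below implements (illustrative sizes; every expert runs on every token here, so this is not the fused dispatch the comment above asks for).

```python
import torch

tokens, hidden, n_experts, top_k = 6, 4, 3, 2
x = torch.randn(tokens, hidden)
experts = [torch.nn.Linear(hidden, hidden) for _ in range(n_experts)]
scores = torch.rand(tokens, n_experts)
topk_weights, topk_indices = scores.topk(top_k, dim=-1)

# Scatter the top-k weights into a dense [tokens, experts] routing matrix,
# run all experts, and combine with a weighted sum over the expert axis.
dense_w = torch.zeros(tokens, n_experts).scatter_(1, topk_indices, topk_weights)
all_out = torch.stack([e(x) for e in experts], dim=1)   # [tokens, experts, hidden]
out = (dense_w[..., None] * all_out).sum(dim=1)
print(out.shape)                                        # [tokens, hidden]
```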
- """ - final_hidden_states = torch.zeros_like(hidden_states, - dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot(topk_indices, - num_classes=len(self.experts)) - expert_mask = expert_mask.permute(2, 0, 1) - - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) - - if token_indices.numel() > 0: - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, - weighted_output) - - # in original deepseek, the output of the experts are gathered once we leave this module - # thus the moe module is itelsf an IsolatedParallel module - # and all expert are "local" meaning we shard but we don't gather - return final_hidden_states.type(hidden_states.dtype) - - def forward(self, hidden_states): - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.moe(hidden_states, topk_indices, - topk_weights).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states - - -class NemotronHTopkRouter(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter( - torch.empty((self.n_routed_experts, config.hidden_size), - dtype=torch.float32)) - self.register_buffer( - "e_score_correction_bias", - torch.zeros(self.n_routed_experts, dtype=torch.float32)) - - @torch.no_grad() - def get_topk_indices(self, scores): - scores_for_choice = scores.view( - -1, - self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) - group_scores = (scores_for_choice.view( - -1, self.n_group, - self.n_routed_experts // self.n_group).topk(2, - dim=-1)[0].sum(dim=-1)) - group_idx = torch.topk(group_scores, - k=self.topk_group, - dim=-1, - sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = (group_mask.unsqueeze(-1).expand( - -1, self.n_group, self.n_routed_experts // self.n_group).reshape( - -1, self.n_routed_experts)) - scores_for_choice = scores_for_choice.masked_fill( - ~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, - k=self.top_k, - dim=-1, - sorted=False)[1] - return topk_indices - - def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) - router_logits = F.linear(hidden_states.type(torch.float32), - self.weight.type(torch.float32)) - scores = router_logits.sigmoid() - topk_indices = self.get_topk_indices(scores) - topk_weights = scores.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of 
torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, - None, :, :].expand(batch, num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, - head_dim) - - -class NemotronHAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, - config: NemotronHConfig, - layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class.") - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - if hasattr(config, "head_dim") and config.head_dim is not None: - self.head_dim = config.head_dim - else: - self.head_dim = config.hidden_size // self.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, - self.num_heads * self.head_dim, - bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.attention_bias) - self.o_proj = nn.Linear(self.head_dim * self.num_heads, - self.hidden_size, - bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] - - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - 
key_states = key_states.contiguous() - value_states = value_states.contiguous() - - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - #attn_output = attn_output.view(bsz, q_len, self.hidden_size) - attn_output = attn_output.view(bsz, q_len, - self.num_heads * self.head_dim) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba -#class JambaFlashAttention2(JambaAttention): -class NemotronHFlashAttention2(NemotronHAttention): - """ - Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. 
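A quick standalone check (restating the `repeat_kv` helper above with toy shapes) that the expand-and-reshape trick matches `torch.repeat_interleave` along the head dimension, which is how grouped key/value heads are broadcast to the full set of attention heads.

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 5, 8)              # [batch, kv_heads, seq, head_dim]
print(torch.equal(repeat_kv(kv, 3),
                  torch.repeat_interleave(kv, 3, dim=1)))   # True
```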
- input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - #attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * - self.head_dim).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba -#class JambaSdpaAttention(JambaAttention): -class NemotronHSdpaAttention(NemotronHAttention): - """ - Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from NemotronHAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -NEMOTRONH_ATTENTION_CLASSES = { - "eager": NemotronHAttention, - "flash_attention_2": NemotronHFlashAttention2, - "sdpa": NemotronHSdpaAttention, -} - - -# Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel -class NemotronHPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. 
- """ - - config_class = NemotronHConfig - base_model_prefix = "backbone" - _no_split_modules = ["NemotronHBlock"] - supports_gradient_checkpointing = True - _is_stateful = True - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, NemotronHMamba2Mixer): - module.A_log._no_weight_decay = True - module.D._no_weight_decay = True - - dt = torch.exp( - torch.rand(self.config.mamba_num_heads) * - (math.log(self.config.time_step_max) - - math.log(self.config.time_step_min)) + - math.log(self.config.time_step_min)).clamp( - min=self.config.time_step_floor) - - # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - module.dt_bias.copy_(inv_dt) - module.dt_bias._no_reinit = True - - if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=self.config.initializer_range) - - # TODO: Check - if self.config.rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(self.config.num_hidden_layers) - - -@dataclass -# Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH -class NemotronHOutput(ModelOutput): - """ - Class for the NemotronH model outputs. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - cache_params (`HybridMambaAttentionDynamicCache`): - The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to - avoid providing the old `input_ids`. - - Includes both the State space model state matrices after the selective scan, and the Convolutional states - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
- """ - - last_hidden_state: Optional[torch.FloatTensor] = None - cache_params: Optional[HybridMambaAttentionDynamicCache] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH -class NemotronHCausalLMOutput(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`HybridMambaAttentionDynamicCache`): - The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to - avoid providing the old `input_ids`. - - Includes both the State space model state matrices after the selective scan, and the Convolutional states - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - cache_params: Optional[HybridMambaAttentionDynamicCache] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -NEMOTRONH_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -NEMOTRONH_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): - Indices of input sequence tokens in the vocabulary. - - If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. - cache_params (`HybridMambaAttentionDynamicCache`, *optional*): - If passed along, the model uses the previous state in all the blocks (which will give the output for the - `input_ids` provided as if the model add `state_input_ids + input_ids` as context). - use_cache (`bool`, *optional*): - If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - The position of the current input in the cache. This is used to ensure that the cache is correctly updated. - If `cache_params` is passed, `cache_position` should also be passed. - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) -""" - - -@add_start_docstrings( - "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.", - NEMOTRONH_START_DOCSTRING, -) -class NemotronHModel(NemotronHPreTrainedModel): - - def __init__(self, config): - super().__init__(config) - - self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - NemotronHBlock(config, layer_idx=idx) - for idx in range(config.num_hidden_layers) - ]) - - self.gradient_checkpointing = False - self.norm_f = NemotronHRMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - # Initialize weights and apply final processing - self._register_load_state_dict_pre_hook(self.load_hook) - self.post_init() - - def load_hook(self, state_dict, prefix, *args): - for k in state_dict: - if "embedding." 
in k: - state_dict[k.replace("embedding.", - "embeddings.")] = state_dict.pop(k) - break - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, new_embeddings): - self.embeddings = new_embeddings - - @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=NemotronHOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[Tuple, NemotronHOutput]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - # use_cache = use_cache if use_cache is not None else self.config.use_cache - use_cache = use_cache if use_cache is not None else ( - self.config.use_cache if not self.training else False) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds - is not None): # ^ is python for xor - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - # From zamba_modeling.py - if use_cache and cache_params is None: - logger.warning_once( - "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. 
None was " - "provided, so no cache will be returned.") - - hidden_states = inputs_embeds - - if cache_position is None: - cache_position = torch.arange(hidden_states.shape[1], - device=hidden_states.device) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, - cache_position) - mamba_mask = self._update_mamba_mask(attention_mask, cache_position) - - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - # Until HERE - - for layer_idx, mixer_block in enumerate(self.layers): - # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) - if mixer_block.block_type == "mamba": - layer_mask = mamba_mask - elif mixer_block.block_type == "attention": - layer_mask = causal_mask - elif mixer_block.block_type in ["mlp", "moe"]: - layer_mask = None - else: - raise ValueError(f"Invalid block_type: {self.block_type}") - - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, - cache_position, layer_mask) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=layer_mask, - ) - - # TODO: Store attentions - # if output_attentions: - # if layer_outputs[1] is not None: - # # append attentions only of attention layers. Mamba layers return `None` as the attention weights - # all_self_attns += (layer_outputs[1],) - - # TODO (Check): should it happen before the forward pass? - # if output_hidden_states: - # all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = self.norm_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - if not return_dict: - return tuple( - v for v in [hidden_states, cache_params, all_hidden_states] - if v is not None) - - return NemotronHOutput( - last_hidden_state=hidden_states, - cache_params=cache_params if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask - def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - target_length = cache_position[-1] + 1 - - causal_mask = torch.full((sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, - device=device) > cache_position.reshape( - -1, 1) - causal_mask = causal_mask[None, - None, :, :].expand(input_tensor.shape[0], 1, - -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone( - ) # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq( - 0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[ - ..., :mask_length].masked_fill(padding_mask, min_dtype) - - if 
(self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda"): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended( - causal_mask, min_dtype) - - return causal_mask - - def _update_mamba_mask(self, attention_mask, cache_position): - """ - No need for zeroing states when - 1. Cached forward - 2. Attending to all inputs - """ - mamba_mask = attention_mask - if cache_position[0] > 0 or (attention_mask is not None - and torch.all(attention_mask == 1)): - mamba_mask = None - return mamba_mask - - -@add_start_docstrings( - """ - The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input - embeddings). - """, - NEMOTRONH_START_DOCSTRING, -) -class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.backbone = NemotronHModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, - config.vocab_size, - bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.backbone.get_input_embeddings() - - def set_input_embeddings(self, new_embeddings): - return self.backbone.set_input_embeddings(new_embeddings) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def get_decoder(self): - return self.model - - def set_decoder(self, decoder): - self.model = decoder - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - **kwargs, - ): - # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py - # Overwritten -- uses `cache_params` as opposed to `past_key_values` - empty_past_kv = past_key_values is None - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
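A toy walk-through of the additive mask built by `_update_causal_mask` above (hypothetical sizes): future positions and padded tokens are filled with the dtype minimum so that softmax assigns them effectively zero weight.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
seq_len = target_len = 4
cache_position = torch.arange(seq_len)

mask = torch.full((seq_len, target_len), min_dtype, dtype=dtype)
mask = torch.triu(mask, diagonal=1)                      # block future positions
mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
mask = mask[None, None].clone()                          # [1, 1, s, t]

attention_mask = torch.tensor([[0, 1, 1, 1]])            # first token is padding
padding = mask.eq(0.0) & attention_mask[:, None, None, :].eq(0)
mask = mask.masked_fill(padding, min_dtype)
print(mask[0, 0])
```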
- # (we can't check exception 3 while compiling) - if not empty_past_kv: - if (inputs_embeds is not None # Exception 1 - or cache_position[-1] >= input_ids.shape[1] # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0]:] - elif input_ids.shape[1] != cache_position.shape[ - 0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - else: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if not empty_past_kv: - position_ids = position_ids[:, -input_ids.shape[1]:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = { - "input_ids": input_ids.contiguous() - } # `contiguous()` needed for compilation use cases - - model_inputs.update({ - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "logits_to_keep": self.config.num_logits_to_keep, - "cache_position": cache_position, - }) - return model_inputs - - @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=NemotronHCausalLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, - labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, # for now we need this for generation - ) -> Union[Tuple, NemotronHCausalLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - nemotron_h_outputs = self.backbone( - input_ids, - cache_params=cache_params, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - use_cache=use_cache, - cache_position=cache_position, - attention_mask=attention_mask, - ) - hidden_states = nemotron_h_outputs[0] - - # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2 - #logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() - logits = self.lm_head(hidden_states.to( - self.lm_head.weight.dtype)).float() - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1)) - - if not return_dict: - output = (logits, ) + nemotron_h_outputs[1:] - return ((loss, ) + output) if loss is not None else output - - return NemotronHCausalLMOutput( - loss=loss, - logits=logits, - cache_params=nemotron_h_outputs.cache_params, - hidden_states=nemotron_h_outputs.hidden_states, - attentions=nemotron_h_outputs.attentions, - ) diff --git a/simple_nemotron.py b/simple_nemotron.py deleted file mode 100644 index 8313922f3ae..00000000000 --- a/simple_nemotron.py +++ /dev/null @@ -1,731 +0,0 @@ -""" -Simplified NemotronHMamba2Mixer - Tensor Algebra Operations Only - -This file focuses on the tensor algebra operations in the Mamba2 forward pass, -with detailed annotations for parallelization across multiple GPUs. - -Notation: -- b: batch_size -- s: seq_len -- h_in: hidden_size (input) -- h: num_heads -- d: head_dim -- n: ssm_state_size -- g: n_groups -- i: intermediate_size (= h * d) -- c: chunk_size -- num_chunks: number of chunks (= ceil(s / c)) - -Key relationships: -- intermediate_size = num_heads * head_dim (i = h * d) -- conv_dim = intermediate_size + 2 * n_groups * ssm_state_size -""" - -from typing import Optional - -import torch -import torch.nn as nn - - -class NemotronHMamba2Mixer: - """ - Mamba2 SSM Mixer - Tensor Algebra Only - - This class contains only the algebraically significant operations, - annotated with parallelization strategies. 
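A quick numeric check of the dimension relationships listed in the notation block above, using the example values from the `__init__` below (illustrative only).

```python
num_heads, head_dim = 64, 64
n_groups, ssm_state_size = 8, 128

intermediate_size = num_heads * head_dim                        # i = h * d = 4096
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size    # 4096 + 2048 = 6144
in_proj_width = intermediate_size + conv_dim + num_heads        # 10304
print(intermediate_size, conv_dim, in_proj_width)
```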
- """ - - def __init__(self): - # Model dimensions (example values) - self.hidden_size = 4096 # h_in - self.num_heads = 64 # h - self.head_dim = 64 # d - self.intermediate_size = 4096 # i = h * d - self.n_groups = 8 # g - self.ssm_state_size = 128 # n - self.chunk_size = 256 # c - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size - - # analogy to transformers' attention: - # A - query [b, s, h, d] - # B - key [b, s, g, d] # n_groups function as num KV heads - # C - value [b, s, g, d] - # D - attention mask - # B and C will be broadcasted from g to h for SSM computation - - # Learnable parameters - conv_kernel = 4 - self.in_proj = nn.Linear( - self.hidden_size, - self.intermediate_size + self.conv_dim + self.num_heads) - self.conv1d = nn.Conv1d( - self.conv_dim, - self.conv_dim, - kernel_size=conv_kernel, - groups=self.conv_dim, - padding=conv_kernel - - 1 # This ensures output length >= input length - ) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size) - self.A_log = nn.Parameter(torch.randn(self.num_heads)) - self.dt_bias = nn.Parameter(torch.randn(self.num_heads)) - self.D = nn.Parameter(torch.randn(self.num_heads)) - - def segment_sum(self, input_tensor): - """ - Segment sum operation - computes cumulative sum within triangular mask. - - Input: [..., chunk_size] - Output: [..., chunk_size, chunk_size] - - PARALLELIZATION ANALYSIS: - - All batch dimensions (...): FULLY PARALLEL (embarrassingly parallel) - - chunk_size dimension: SEQUENTIAL (cumsum is inherently sequential) - - Can parallelize across chunks if processing multiple chunks - - Cross-GPU: Can distribute batch/head dimensions, but cumsum requires local computation - """ - chunk_size = input_tensor.size(-1) - - # Input: [..., c] -> [..., c, c] - # Complexity: O(c^2) per element in batch - # Parallel: All leading dims are independent - input_tensor = input_tensor[..., None].expand(*input_tensor.size(), - chunk_size) - - # Cumsum along dim=-2 - # Input: [..., c, c], Output: [..., c, c] - # Complexity: O(c^2) per element - # Parallel: Leading dims (...) are independent, but cumsum is sequential in last-2 dim - # WARNING: cumsum is NOT parallelizable in the reduction dimension - tensor_segsum = torch.cumsum(input_tensor, dim=-2) - - return tensor_segsum - - def torch_forward_algebra_only( - self, - input_states: torch.Tensor, # [b, s, h_in] - cache_params: Optional = None, - debug: bool = False): - """ - Forward pass with TENSOR ALGEBRA operations only. - Focus: matrix multiplications, reductions, cumulative operations. - Excluded: element-wise ops (activations, exp, masking, etc.) 
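        For reference, the masking that the simplified segment_sum above leaves out looks
        roughly as follows (a minimal sketch following the reference Mamba2 formulation;
        the helper name segment_sum_with_mask is illustrative):

            def segment_sum_with_mask(a: torch.Tensor) -> torch.Tensor:
                # [..., c] -> [..., c, c]; entry (i, j) = sum_{k=j+1..i} a[..., k], -inf above the diagonal
                c = a.size(-1)
                a = a[..., None].expand(*a.size(), c)
                lower = torch.tril(torch.ones(c, c, dtype=torch.bool, device=a.device), diagonal=-1)
                a = a.masked_fill(~lower, 0)                # zero out non-causal contributions
                segsum = torch.cumsum(a, dim=-2)            # sequential only along one chunk axis
                causal = torch.tril(torch.ones(c, c, dtype=torch.bool, device=a.device), diagonal=0)
                return segsum.masked_fill(~causal, float("-inf"))  # exp() later maps -inf to 0

        Exponentiating the result gives the lower-triangular intra-chunk decay matrix L used in Step 3c.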
- """ - - batch_size, seq_len, _ = input_states.shape # b, s, h_in - - # ============================================================================= - # STEP 1: Input Projection (Linear Layer) - # ============================================================================= - # Operation: projected_states = input_states @ in_proj.weight^T + in_proj.bias - # Input: [b, s, h_in] - # Weight: [projection_size, h_in] where projection_size = i + conv_dim + h - # Output: [b, s, projection_size] - # Complexity: O(b * s * h_in * projection_size) ≈ O(b * s * h_in^2) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - can split across GPUs with no communication - # - Sequence (s): FULLY PARALLEL - can split across GPUs with no communication - # - Hidden (h_in): PARALLEL with ALL_REDUCE - this is the reduction dimension - # * If split h_in across GPUs, need all_reduce to sum partial results - # * Row-parallel: split weight rows, no all_reduce needed - # * Column-parallel: split weight columns, need all_reduce after - # - Output (projection_size): PARALLEL - row-wise split requires no communication - # - # TENSOR PARALLEL STRATEGIES: - # 1. Batch parallel: Each GPU processes different batch elements - # 2. Sequence parallel: Each GPU processes different tokens (works for attention) - # 3. Tensor parallel (column): Split projection_size, all_reduce on h_in - # 4. Tensor parallel (row): Split h_in, each GPU computes partial projection - projected_states = self.in_proj(input_states) # [b, s, projection_size] - - # Split the projection into components - # gate: [b, s, i], hidden_states_B_C: [b, s, conv_dim], dt: [b, s, h] - # Note: d_mlp is computed but will be 0 in this configuration - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2 - - if debug: - print("\nProjection split:") - print(f" projected_states shape: {projected_states.shape}") - print(f" d_mlp: {d_mlp}") - print(f" Split sizes: [d_mlp={d_mlp}, d_mlp={d_mlp}, " - f"intermediate={self.intermediate_size}, " - f"conv_dim={self.conv_dim}, num_heads={self.num_heads}]") - print( - f" Total: {2*d_mlp + self.intermediate_size + self.conv_dim + self.num_heads}" - ) - - # Split into components (d_mlp will be 0, so first two splits are empty) - splits = [] - current_idx = 0 - for size in [ - d_mlp, d_mlp, self.intermediate_size, self.conv_dim, - self.num_heads - ]: - if size > 0: - splits.append(projected_states[..., - current_idx:current_idx + size]) - else: - splits.append( - projected_states[..., - current_idx:current_idx]) # Empty tensor - current_idx += size - - _, _, gate, hidden_states_B_C, dt = splits[0], splits[1], splits[ - 2], splits[3], splits[4] - - if debug: - print( - f" After split - gate: {gate.shape}, hidden_states_B_C: {hidden_states_B_C.shape}, dt: {dt.shape}" - ) - - # ============================================================================= - # STEP 2: Conv1D Operation - # ============================================================================= - # Conv1D is applied on sequence dimension - # Input: [b, conv_dim, s] (after transpose) - # Weight: [conv_dim, 1, kernel_size] - # Output: [b, conv_dim, s] - # Complexity: O(b * conv_dim * s * kernel_size) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Channel (conv_dim): FULLY PARALLEL (depthwise conv, groups=conv_dim) - # - Sequence (s): PARALLEL with communication - # * Conv requires kernel_size-1 halo elements from neighbors - # * Split sequence: need halo exchange between 
GPUs - # * First/last kernel_size-1 tokens need data from adjacent GPUs - # - # TENSOR PARALLEL STRATEGIES: - # 1. Batch parallel: Easiest, no communication - # 2. Channel parallel: Split conv_dim, no cross-channel communication (depthwise) - # 3. Sequence parallel: Need halo exchange (communication overhead) - hidden_states_B_C_transposed = hidden_states_B_C.transpose( - 1, 2) # [b, conv_dim, s] - conv_out = self.conv1d(hidden_states_B_C_transposed)[ - ..., :seq_len] # [b, conv_dim, s] - hidden_states_B_C = conv_out.transpose(1, 2) # [b, s, conv_dim] - - # Split conv output - # conv_dim = intermediate_size + 2 * n_groups * ssm_state_size - split_sizes = [ - self.intermediate_size, self.n_groups * self.ssm_state_size, - self.n_groups * self.ssm_state_size - ] - - # Verify split sizes match conv_dim - assert sum( - split_sizes - ) == self.conv_dim, f"Split sizes {split_sizes} don't sum to conv_dim {self.conv_dim}" - - hidden_states = hidden_states_B_C[..., :self.intermediate_size] - B = hidden_states_B_C[..., - self.intermediate_size:self.intermediate_size + - self.n_groups * self.ssm_state_size] - C = hidden_states_B_C[..., self.intermediate_size + - self.n_groups * self.ssm_state_size:] - - # hidden_states: [b, s, i], B: [b, s, g*n], C: [b, s, g*n] - if debug: - print( - f"After split - hidden_states: {hidden_states.shape}, B: {B.shape}, C: {C.shape}" - ) - - # ============================================================================= - # STEP 3: SSM State Space Computation (Main Computation) - # ============================================================================= - - # Reshape for SSM computation - # hidden_states: [b, s, i] -> [b, s, h, d] - # B: [b, s, g*n] -> [b, s, g, n] - # C: [b, s, g*n] -> [b, s, g, n] - # Complexity: O(1) - just view operations - # Parallel: All dimensions are independent - - if debug: - print( - f"Before reshape - hidden_states: {hidden_states.shape}, expected: [{batch_size}, {seq_len}, {self.intermediate_size}]" - ) - print( - f"Reshape target: [{batch_size}, {seq_len}, {self.num_heads}, {self.head_dim}]" - ) - print( - f"intermediate_size={self.intermediate_size}, num_heads={self.num_heads}, head_dim={self.head_dim}" - ) - print(f"num_heads * head_dim = {self.num_heads * self.head_dim}") - - # Verify dimensions are compatible - assert hidden_states.shape[-1] == self.num_heads * self.head_dim, \ - f"Cannot reshape {hidden_states.shape} to have {self.num_heads} heads of dim {self.head_dim}" - - hidden_states = hidden_states.reshape(batch_size, seq_len, - self.num_heads, self.head_dim) - B = B.reshape(batch_size, seq_len, self.n_groups, self.ssm_state_size) - C = C.reshape(batch_size, seq_len, self.n_groups, self.ssm_state_size) - - if debug: - print( - f"After reshape - hidden_states: {hidden_states.shape}, B: {B.shape}, C: {C.shape}" - ) - - # Repeat B and C to match num_heads (from n_groups) - # Input: [b, s, g, n] - # Output: [b, s, h, n] where h = g * repetition_factor - # Complexity: O(b * s * h * n) memory, O(1) compute (just indexing) - # Parallel: FULLY PARALLEL - simple replication - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) # [b, s, h, n] - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) # [b, s, h, n] - - # Compute pad size for chunking - pad_size = (self.chunk_size - - seq_len % self.chunk_size) % self.chunk_size - - # ============================================================================= - # STEP 3a: Chunk Reshaping - # ============================================================================= - # Reshape 
sequences into chunks - # Input: [b, s, h, d] - # Output: [b, num_chunks, c, h, d] where num_chunks = ceil(s/c) - # Complexity: O(1) - reshape only - # Parallel: All dimensions independent - # - # Note: This creates a new dimension (num_chunks) that can be parallelized! - def reshape_into_chunks(tensor, pad_size, chunk_size): - """Pad and reshape into chunks""" - # Pad: increases sequence length by pad_size - # Reshape: [b, s+pad, ...] -> [b, num_chunks, c, ...] - # Parallel: Independent across batch dimension - return tensor # Simplified - actual implementation in original code - - # After chunking (conceptual): - # hidden_states: [b, num_chunks, c, h, d] - # A: [b, h, num_chunks, c] (permuted for computation) - # B: [b, num_chunks, c, h, n] - # C: [b, num_chunks, c, h, n] - - # ============================================================================= - # STEP 3b: Cumulative Sum (Sequential Operation) - # ============================================================================= - # A_cumsum = torch.cumsum(A, dim=-1) - # Input: [b, h, num_chunks, c] - # Output: [b, h, num_chunks, c] - # Complexity: O(b * h * num_chunks * c) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Heads (h): FULLY PARALLEL - # - Chunks (num_chunks): FULLY PARALLEL - each chunk is independent! - # - Chunk_size (c): SEQUENTIAL - cumsum is inherently sequential - # - # CRITICAL: cumsum within each chunk is sequential, but different chunks - # can be computed in parallel! This is why chunking is valuable. - # - # TENSOR PARALLEL: Can split b, h, num_chunks across GPUs with NO communication - A_cumsum = torch.zeros(batch_size, self.num_heads, - (seq_len + pad_size) // self.chunk_size, - self.chunk_size) - - # ============================================================================= - # STEP 3c: Segment Sum (calls cumsum internally) - # ============================================================================= - # L = torch.exp(segment_sum(A)) - # segment_sum input: [b, h, num_chunks, c] - # segment_sum output: [b, h, num_chunks, c, c] - # Complexity: O(b * h * num_chunks * c^2) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Heads (h): FULLY PARALLEL - # - Chunks (num_chunks): FULLY PARALLEL - # - Within chunk (c): SEQUENTIAL (cumsum) - # - Output dimension (c): Creates new dimension, parallel - # - # TENSOR PARALLEL: Can split b, h, num_chunks across GPUs with NO communication - # The c x c matrix per chunk is computed locally on each GPU - L = self.segment_sum(A_cumsum) # [b, h, num_chunks, c, c] - - # ============================================================================= - # STEP 3d: Attention-like Computation (G matrix) - # ============================================================================= - # G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] - # Input C: [b, num_chunks, c, h, n] - # Input B: [b, num_chunks, c, h, n] - # Output: [b, num_chunks, c, c, h, n] - # Then: G = G_intermediate.sum(dim=-1) -> [b, num_chunks, c, c, h] - # Complexity: O(b * num_chunks * c^2 * h * n) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Chunks (num_chunks): FULLY PARALLEL - # - Query positions (c): FULLY PARALLEL - # - Key positions (c): FULLY PARALLEL - # - Heads (h): FULLY PARALLEL - # - State dimension (n): PARALLEL with ALL_REDUCE (this is reduction dim) - # * If split n across GPUs, need all_reduce after sum - # - # TENSOR PARALLEL STRATEGIES: - # 1. Split any of (b, num_chunks, h) with no communication - # 2. 
Split n with all_reduce after reduction - # 3. This is similar to attention QK^T computation! - C_expanded = torch.zeros(batch_size, - (seq_len + pad_size) // self.chunk_size, - self.chunk_size, self.chunk_size, - self.num_heads, self.ssm_state_size) - B_expanded = torch.zeros(batch_size, - (seq_len + pad_size) // self.chunk_size, - self.chunk_size, self.chunk_size, - self.num_heads, self.ssm_state_size) - G_intermediate = C_expanded * B_expanded # [b, num_chunks, c, c, h, n] - - # Reduction over state dimension - # Input: [b, num_chunks, c, c, h, n] - # Output: [b, num_chunks, c, c, h] - # Complexity: O(b * num_chunks * c^2 * h * n) - # Parallel: Reduction over n - if n is split, need all_reduce - G = G_intermediate.sum(dim=-1) # [b, num_chunks, c, c, h] - - # ============================================================================= - # STEP 3e: Attention Weights Computation (M matrix) - # ============================================================================= - # M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] - # After permute, L: [b, num_chunks, c, c, h] - # G: [b, num_chunks, c, c, h] - # M_intermediate: [b, num_chunks, c, c, h, d] (after broadcasting) - # M = M_intermediate.sum(dim=-1) -> [b, num_chunks, c, c, h, d] - # Complexity: O(b * num_chunks * c^2 * h * d) - # - # PARALLELIZATION: - # - All of (b, num_chunks, c, c, h, d) are independent in the outer product - # - The sum reduction is over a broadcasted dimension - # - FULLY PARALLEL across b, num_chunks, c, c, h - # - d dimension: depends on reduction - L_permuted = L.permute(0, 2, 3, 4, 1) # [b, num_chunks, c, c, h] - M_intermediate = torch.zeros(batch_size, - (seq_len + pad_size) // self.chunk_size, - self.chunk_size, self.chunk_size, - self.num_heads, self.head_dim) - M = M_intermediate.sum( - dim=-1) # Simplified - actual computation more complex - - # ============================================================================= - # STEP 3f: Intra-chunk Output (Y_diag) - # ============================================================================= - # Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) - # M: [b, num_chunks, c, c, h, d] - # hidden_states after chunking: [b, num_chunks, c, h, d] - # After broadcasting: [b, num_chunks, c, c, h, d] - # Output after sum: [b, num_chunks, c, h, d] - # Complexity: O(b * num_chunks * c^2 * h * d) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Chunks (num_chunks): FULLY PARALLEL - # - Output positions (c, dim=2): FULLY PARALLEL - # - Input positions (c, dim=3): PARALLEL with ALL_REDUCE (reduction dimension) - # - Heads (h): FULLY PARALLEL - # - Head_dim (d): FULLY PARALLEL - # - # This is essentially the attention "apply to values" step! 
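        # A shape-level einsum sketch of Steps 3d-3f (z = num_chunks; x, B, C assumed
        # already chunked to [b, z, c, h, d] / [b, z, c, h, n], L as [b, h, z, c, c]);
        # not the exact reference code, just the structure of the reductions:
        #
        #   G      = torch.einsum("bzlhn,bzshn->bzlsh", C, B)   # "Q @ K^T", reduces over n
        #   M      = G * L.permute(0, 2, 3, 4, 1)               # apply intra-chunk decay
        #   Y_diag = torch.einsum("bzlsh,bzshd->bzlhd", M, x)   # "apply to values", reduces over s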
- # TENSOR PARALLEL: Split b, num_chunks, h with no communication - # If split input c dimension, need all_reduce - Y_diag = torch.zeros(batch_size, - (seq_len + pad_size) // self.chunk_size, - self.chunk_size, self.num_heads, self.head_dim) - - # ============================================================================= - # STEP 3g: Intra-chunk State Computation - # ============================================================================= - # B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] - # states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) - # B_decay: [b, num_chunks, c, h, n] - # hidden_states: [b, num_chunks, c, h, d] - # After broadcasting: [b, num_chunks, c, h, d, n] - # After sum over c: [b, num_chunks, h, d, n] - # Complexity: O(b * num_chunks * c * h * d * n) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Chunks (num_chunks): FULLY PARALLEL (each chunk's state independent) - # - Sequence within chunk (c, dim=2): PARALLEL with ALL_REDUCE (reduction) - # - Heads (h): FULLY PARALLEL - # - Head_dim (d): FULLY PARALLEL - # - State_size (n): FULLY PARALLEL - # - # TENSOR PARALLEL: Can split b, num_chunks, h, d, n with no communication - # If split c dimension, need all_reduce after sum - states = torch.zeros(batch_size, - (seq_len + pad_size) // self.chunk_size, - self.num_heads, self.head_dim, self.ssm_state_size) - - # ============================================================================= - # STEP 3h: Inter-chunk Recurrence (Sequential Across Chunks!) - # ============================================================================= - # decay_chunk = torch.exp(segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) - # decay_chunk = decay_chunk.transpose(1, 3) # [b, num_chunks+1, num_chunks+1, h] - # new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) - # - # Input states: [b, num_chunks, h, d, n] - # decay_chunk: [b, num_chunks+1, num_chunks+1, h] - # new_states: [b, num_chunks+1, h, d, n] - # - # Complexity: O(b * num_chunks^2 * h * d * n) - # - # PARALLELIZATION - CRITICAL INSIGHT: - # - Batch (b): FULLY PARALLEL - # - Heads (h): FULLY PARALLEL - # - Head_dim (d): FULLY PARALLEL - # - State_size (n): FULLY PARALLEL - # - Chunks (num_chunks): SEQUENTIAL!!! This is a recurrence across chunks! - # - # **This is the main sequential bottleneck for long sequences!** - # - # The sum over dim=1 creates a dependency between chunks: - # new_states[chunk_i] depends on states[0:i] - # - # TENSOR PARALLEL STRATEGIES: - # 1. Can split b, h, d, n across GPUs with no communication - # 2. CANNOT efficiently parallelize across chunks without changing algorithm - # 3. For very long sequences, this becomes a bottleneck - # 4. 
Possible solution: Use ring-reduce or prefix-sum parallel algorithms - # but this requires O(log num_chunks) communication rounds - # - # Alternative: Pipeline parallelism - process chunks sequentially but - # overlap computation of different layers - new_states = torch.zeros(batch_size, - (seq_len + pad_size) // self.chunk_size + 1, - self.num_heads, self.head_dim, - self.ssm_state_size) - - # Extract final state and intermediate states - states = new_states[:, :-1] # [b, num_chunks, h, d, n] - ssm_state = new_states[:, -1] # [b, h, d, n] - final state for caching - - # ============================================================================= - # STEP 3i: State to Output (Y_off) - # ============================================================================= - # C_times_states = (C[..., None, :] * states[:, :, None, ...]) - # Input C: [b, num_chunks, c, h, n] - # Input states: [b, num_chunks, h, d, n] - # After broadcast: [b, num_chunks, c, h, d, n] - # Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - # After sum over n: [b, num_chunks, c, h, d] - # Complexity: O(b * num_chunks * c * h * d * n) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Chunks (num_chunks): FULLY PARALLEL (using precomputed states) - # - Positions (c): FULLY PARALLEL - # - Heads (h): FULLY PARALLEL - # - Head_dim (d): FULLY PARALLEL - # - State_size (n): PARALLEL with ALL_REDUCE (reduction dimension) - # - # TENSOR PARALLEL: Split b, num_chunks, c, h, d with no communication - # Split n requires all_reduce after sum - Y_off = torch.zeros(batch_size, (seq_len + pad_size) // self.chunk_size, - self.chunk_size, self.num_heads, self.head_dim) - - # ============================================================================= - # STEP 3j: Combine Intra-chunk and Inter-chunk Outputs - # ============================================================================= - # y = Y_diag + Y_off - # Both: [b, num_chunks, c, h, d] - # Output: [b, num_chunks, c, h, d] -> [b, s, h, d] -> [b, s, i] - # Complexity: O(b * s * i) for reshape - # Parallel: FULLY PARALLEL (element-wise addition) - y = Y_diag + Y_off # [b, num_chunks, c, h, d] - y = y.reshape(batch_size, -1, self.num_heads, - self.head_dim) # [b, s_padded, h, d] - y = y[:, :seq_len, :, :] # Remove padding: [b, s, h, d] - y = y.reshape(batch_size, seq_len, self.intermediate_size) # [b, s, i] - - # ============================================================================= - # STEP 4: Output Projection (Linear Layer) - # ============================================================================= - # contextualized_states = y @ out_proj.weight^T + out_proj.bias - # Input: [b, s, i] - # Weight: [h_in, i] - # Output: [b, s, h_in] - # Complexity: O(b * s * i * h_in) - # - # PARALLELIZATION: - # - Batch (b): FULLY PARALLEL - # - Sequence (s): FULLY PARALLEL - # - Input dim (i): PARALLEL with ALL_REDUCE (reduction dimension) - # - Output dim (h_in): FULLY PARALLEL (row-parallel) - # - # TENSOR PARALLEL STRATEGIES: - # 1. Column parallel on i: split weight columns, all_reduce after matmul - # 2. Row parallel on h_in: split weight rows, no all_reduce needed - # 3. 
Typically: in_proj is column-parallel, out_proj is row-parallel - # This minimizes communication (1 all_reduce per layer) - contextualized_states = self.out_proj(y) # [b, s, h_in] - - return contextualized_states - - def summarize_parallelization_strategies(self): - """ - SUMMARY OF PARALLELIZATION STRATEGIES FOR MULTI-GPU DEPLOYMENT - ================================================================ - - DIMENSIONS AND THEIR PARALLELIZABILITY: - - 1. BATCH (b) - EMBARRASSINGLY PARALLEL - - Can split across GPUs with ZERO communication - - Each GPU processes different examples - - Strategy: Data Parallelism - - 2. SEQUENCE (s) - MOSTLY PARALLEL with caveats - - Linear layers: FULLY PARALLEL - - Conv1d: Needs halo exchange (kernel_size-1 elements) - - Attention-like ops: FULLY PARALLEL - - Within chunks: FULLY PARALLEL - - Across chunks: SEQUENTIAL (recurrence) - - Strategy: Sequence Parallelism (limited by chunk recurrence) - - 3. HEADS (h) - FULLY PARALLEL - - All operations independent across heads - - No communication needed - - Strategy: Tensor Parallelism on head dimension - - 4. HEAD_DIM (d) - PARALLEL (no reductions in this dim) - - Can split with no all_reduce - - Strategy: Tensor Parallelism on head_dim - - 5. HIDDEN_DIM (h_in, i) - PARALLEL with ALL_REDUCE - - Linear layers: reduction dimension - - Need all_reduce when splitting this dimension - - Strategy: Tensor Parallelism (column-parallel in, row-parallel out) - - 6. STATE_SIZE (n) - PARALLEL with ALL_REDUCE - - Reduction dimension in attention-like operations - - Need all_reduce when computing G and Y_off - - Strategy: Tensor Parallelism on state dimension - - 7. NUM_CHUNKS - MOSTLY PARALLEL - - Each chunk computation: FULLY PARALLEL - - Chunk recurrence (Step 3h): SEQUENTIAL - - Strategy: Pipeline or sequential processing - - 8. CHUNK_SIZE (c) - MIXED - - Cumsum/segment_sum: SEQUENTIAL within chunk - - Other ops: PARALLEL - - Cannot split within chunk effectively - - RECOMMENDED MULTI-GPU STRATEGIES: - ================================== - - Strategy 1: TENSOR + DATA PARALLEL (Most Common) - ------------------------------------------------- - - Split batch across data-parallel GPUs (no communication) - - Within each data-parallel group, use tensor parallelism: - * Split num_heads across GPUs (no communication in compute) - * Column-parallel in_proj, row-parallel out_proj (1 all_reduce per layer) - - Works well for moderate sequence lengths - - Communication: O(b * s * h_in) per layer for all_reduce - - Strategy 2: SEQUENCE PARALLEL (For Very Long Sequences) - -------------------------------------------------------- - - Split sequence dimension across GPUs - - Requires: - * Halo exchange for conv1d (small overhead) - * Sequential processing of chunk recurrence (pipelined) - - Best for: seq_len >> hidden_size - - Communication: O(conv_kernel * features) for halo + pipeline latency - - Strategy 3: EXPERT PARALLEL (If MOE layers present) - --------------------------------------------------- - - Not shown in this code, but relevant for full model - - Split experts across GPUs - - All-to-all communication for routing - - Strategy 4: PIPELINE PARALLEL (For Very Large Models) - ----------------------------------------------------- - - Split layers across GPUs - - Process micro-batches in pipeline - - Communication: O(b * s * h_in) per pipeline stage boundary - - CRITICAL BOTTLENECKS: - ===================== - - 1. 
CHUNK RECURRENCE (Step 3h) - - Sequential across chunks - - Cannot parallelize without algorithmic changes - - For long sequences with many chunks, this limits speedup - - Mitigation: Use larger chunk_size (but increases memory) - - 2. CUMSUM OPERATIONS - - Sequential within each chunk - - Limits parallelism to chunk_size granularity - - Cannot split chunk_size dimension across GPUs - - 3. ALL_REDUCE COMMUNICATION - - Required when splitting reduction dimensions - - Latency increases with number of GPUs - - Bandwidth-bound for large tensors - - 4. CONV1D HALO EXCHANGE - - Required for sequence parallelism - - Small overhead but adds latency - - OPTIMAL CONFIGURATION (Example for 8 GPUs): - =========================================== - - Use 4-way tensor parallelism on heads (split 64 heads -> 16 per GPU) - - Use 2-way data parallelism on batch - - Keep sequence on single GPU (if possible) - - If sequence too long: - * Use sequence parallelism with 2-4 way split - * Accept chunk recurrence as sequential bottleneck - - This gives: - - ~4x speedup from tensor parallelism (limited by all_reduce) - - ~2x speedup from data parallelism (perfect scaling) - - Total: ~6-7x speedup on 8 GPUs (75-85% efficiency) - """ - - -def main(): - """ - Example usage showing the tensor shapes through the forward pass. - """ - import sys - - # Check if debug mode is requested - debug = "--debug" in sys.argv - - mixer = NemotronHMamba2Mixer() - - # Example input - batch_size = 4 - seq_len = 1024 - hidden_size = 4096 - - input_states = torch.randn(batch_size, seq_len, hidden_size) - - # Forward pass - output = mixer.torch_forward_algebra_only(input_states, debug=debug) - - print(f"\n{'='*80}") - print("NEMOTRON-H MAMBA2 MIXER - TENSOR ALGEBRA ANALYSIS") - print(f"{'='*80}") - print(f"\nInput shape: {input_states.shape}") - print(f"Output shape: {output.shape}") - print("\nConfiguration:") - print(f" - Batch size: {batch_size}") - print(f" - Sequence length: {seq_len}") - print(f" - Hidden size: {hidden_size}") - print(f" - Num heads: {mixer.num_heads}") - print(f" - Head dim: {mixer.head_dim}") - print(f" - Intermediate size: {mixer.intermediate_size}") - print(f" - Chunk size: {mixer.chunk_size}") - print( - f" - Num chunks: {(seq_len + mixer.chunk_size - 1) // mixer.chunk_size}" - ) - print(f"\n{'='*80}") - print("See docstrings in the code for detailed parallelization analysis.") - print("Run with --debug flag to see intermediate tensor shapes.") - print(f"{'='*80}\n") - - -if __name__ == "__main__": - main() From 678d21934575a58ae659b5419b6bc1525b59123e Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Thu, 23 Oct 2025 03:13:48 -0700 Subject: [PATCH 08/12] cleanup Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 4 +- .../auto_deploy/models/patches/nemotron_h.py | 34 ------ .../auto_deploy/transform/library/sharding.py | 20 ---- .../auto_deploy/utils/sharding_utils.py | 113 ++++++++---------- 4 files changed, 55 insertions(+), 116 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 1acdc118fdf..13f1cf0703f 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -75,8 +75,8 @@ transforms: detect_sharding: stage: sharding simple_shard_only: false - use_sharding_from_factory: true - support_partial_config: true + 
use_sharding_from_factory: false + support_partial_config: false sharding_dims: ['tp', 'ep', 'bmm'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index aefb66a5e65..72ecd0945e2 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -12,11 +12,6 @@ from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward -# from transformers.models.nemotron_h.configuration_nemotron_h import NemotronHConfig - -# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57 -# NemotronHConfig.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise" - # Forked from: # https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py @@ -94,34 +89,6 @@ def _nemotron_h_block_forward( return hidden_states -# TODO: we assume experts have no bias for now -def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): - """ - Uses NemotronH router (returns indices, weights) and dispatches through auto_deploy::torch_moe_nemo - with act_fn='relu2'. Falls back to original forward if any expert has bias. - """ - - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) - x_flat = hidden_states.view(-1, hidden_states.shape[-1]) - - out_flat = torch.ops.auto_deploy.torch_moe( - x_flat, - topk_indices, - topk_weights, - w1_weight=[e.up_proj.weight for e in self.experts], - w2_weight=[e.down_proj.weight for e in self.experts], - w3_weight=[], - act_fn="relu2", - mlp_style="mlp", - ) - - out = out_flat.view(*orig_shape) - out = out + self.shared_experts(residuals) - return out - - _from_config_original = AutoModelForCausalLM.from_config CUSTOM_MODULE_PATCHES: Dict[str, List[Tuple[str, Callable]]] = { @@ -131,7 +98,6 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): ("_update_mamba_mask", _nemotron_h_model_update_mamba_mask), ], "NemotronHBlock": [("forward", _nemotron_h_block_forward)], - "NemotronHMOE": [("forward", _nemotron_h_moe_forward)], } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 6ab1d024f40..30ed09f718a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -16,7 +16,6 @@ happens automatically via the checkpoint loading hook added in step 2c. """ -import ast import operator import re from collections import defaultdict @@ -310,24 +309,6 @@ def detect_sharding_from_factory_config( num_shards += 1 # we have a match. Get the config for this layer config = tp_plan[key] - # check if config has parameters. - if "(" in config: - config, params_str = config.split("(", 1) - params_str = params_str.rsplit(")", 1)[0] # Remove trailing ) - - try: - # Convert "key" = value to "key": value format for dict parsing - params_str = params_str.replace(" = ", ": ") - # Wrap in braces to make it a dict and parse - config_params = ast.literal_eval("{" + params_str + "}") - except Exception as e: - ad_logger.warning( - f"Failed to parse config params: {params_str}, error: {e}. " - "Using empty config." 
- ) - config_params = {} - else: - config_params = {} if config == "colwise": sharding_config.tp_transforms.append( TPShardingInfo.from_node( @@ -362,7 +343,6 @@ def detect_sharding_from_factory_config( dist_op=None, min_local_shape=min_local_shape, layer_type=LayerType.MAMBA, - fused_weight_dims=config_params.get("fused_weight_dims"), ) ) num_row_col_shards += 1 diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 03a225ce51a..40b45592370 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -150,75 +150,59 @@ def shard_weight_tensor( Tuple of (sharded_tensor, sharded_shape) """ - # Use custom shard function if provided - if custom_shard_fn is not None: - sharded_weight = custom_shard_fn(weight_tensor) - sharded_shape = sharded_weight.shape - # Register load hook with custom function - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, - f_split=custom_shard_fn, - param_key=param_key, - param_shape=sharded_shape, + def split_tensor( + t: torch.Tensor, + d: int = dim, + r: int = rank, + ws: int = world_size, + min_d_shape: int = min_local_shape, + ) -> torch.Tensor: + # The local tensor shape has to be divisible by min_d_shape + max_split_size = t.shape[d] // min_d_shape + if ws > max_split_size: + num_groups = math.ceil(ws / max_split_size) + ad_logger.debug( + f"World size {ws} is greater than the max split size {max_split_size}. " + + f"Splitting tensor to {num_groups} chunks" ) - ) + return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups] + return torch.tensor_split(t, ws, dim=d)[r] - else: + # Handle fused weights + if fused_weight_dims is not None: + # Split fused weights, apply TP sharding to each, then concatenate back + sharded_weight = torch.cat( + [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], + dim=dim, + ) - def split_tensor( + # Create a function that applies the same logic for loading + def split_fused_tensor( t: torch.Tensor, + fused_dims: list = fused_weight_dims, d: int = dim, - r: int = rank, - ws: int = world_size, - min_d_shape: int = min_local_shape, ) -> torch.Tensor: - # The local tensor shape has to be divisible by min_d_shape - max_split_size = t.shape[d] // min_d_shape - if ws > max_split_size: - num_groups = math.ceil(ws / max_split_size) - ad_logger.debug( - f"World size {ws} is greater than the max split size {max_split_size}. 
" - + f"Splitting tensor to {num_groups} chunks" - ) - return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups] - return torch.tensor_split(t, ws, dim=d)[r] - - # Handle fused weights - if fused_weight_dims is not None: - # Split fused weights, apply TP sharding to each, then concatenate back - sharded_weight = torch.cat( - [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], - dim=dim, + return torch.cat( + [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)], + dim=d, ) - # Create a function that applies the same logic for loading - def split_fused_tensor( - t: torch.Tensor, - fused_dims: list = fused_weight_dims, - d: int = dim, - ) -> torch.Tensor: - return torch.cat( - [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)], - dim=d, - ) - - f_split = split_fused_tensor - else: - sharded_weight = split_tensor(weight_tensor) - f_split = split_tensor - - sharded_shape = sharded_weight.shape - - # Register load hook - gm._register_load_state_dict_pre_hook( - partial( - _load_hook, - f_split=f_split, - param_key=param_key, - param_shape=sharded_shape, - ) + f_split = split_fused_tensor + else: + sharded_weight = split_tensor(weight_tensor) + f_split = split_tensor + + sharded_shape = sharded_weight.shape + + # Register load hook + gm._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=f_split, + param_key=param_key, + param_shape=sharded_shape, ) + ) # Update the parameter in the module if update_param: @@ -332,6 +316,9 @@ def _insert_sharded_mamba( "conv1d": split_sizes_2, } + ############################################################## + ############# update conv1d num output channels ############## + ############################################################## conv1d_node = conv1d_nodes[0] # conv1d_node last argument is the number of output channels. 
# This one is also sharded, so we need to update this parameter @@ -349,6 +336,9 @@ def _insert_sharded_mamba( entry_fused_dims = dims break + ############################################################## + ####### shard the entry_node (the first linear layer) ######## + ############################################################## _insert_sharded_matmul( gm=gm, node=entry_node, @@ -361,6 +351,9 @@ def _insert_sharded_mamba( quantization_cb=quantization_cb, ) + ############################################################## + ############## shard the remaining weights ################### + ############################################################## # Get all weight nodes in the subgraph except for out_proj weight_nodes = [ n From b7e264da477d353e85521567e669961fed9bb208 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sat, 25 Oct 2025 23:51:08 -0700 Subject: [PATCH 09/12] wip Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/transform/library/sharding.py | 359 ++++++++++++++++-- .../auto_deploy/utils/sharding_utils.py | 111 +++--- .../library/test_tp_sharding.py | 10 +- 3 files changed, 394 insertions(+), 86 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 30ed09f718a..631925d4a80 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -29,20 +29,23 @@ from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger from ...utils.node_utils import ( + bfs, filtered_nodes, identify_regions_between_residuals, is_fake_quantized_linear_op, is_linear_op, is_op, + subgraph, ) from ...utils.sharding_utils import ( BMMShardingInfo, EPShardingInfo, LayerType, + ParameterUpdateInfo, ShardingConfig, ShardingTransformInfo, SplitDimension, - TPShardingInfo, + WeightShardingInfo, ) from ..interface import ( BaseTransform, @@ -83,7 +86,7 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: return transform.check_and_apply(gm, node_dict[transform.target_node]) num_matches = 0 - for tp_transform in shared_config.sharding_config.tp_transforms: + for tp_transform in shared_config.sharding_config.weight_sharding_transforms: if check_and_apply(tp_transform): num_matches += 1 for bmm_transform in shared_config.sharding_config.bmm_transforms: @@ -93,6 +96,11 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: if check_and_apply(ep_transform): num_matches += 1 + # post-sharding cleanup transformations + for update_transform in shared_config.sharding_config.parameter_update_transforms: + if not check_and_apply(update_transform): + ad_logger.warning(f"Invalid parameter update transformation {update_transform}.") + info = TransformInfo( skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False ) @@ -107,11 +115,11 @@ def _append_simple_shard( ) -> None: # for every linear node: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[TPShardingInfo] = [] + tp_shards: List[WeightShardingInfo] = [] for node_group in nodes_linear.values(): for n in node_group: tp_shards.append( - TPShardingInfo.from_node( + WeightShardingInfo.from_node( n, split_dim=SplitDimension.COLUMN, rank=rank, @@ -120,7 +128,7 @@ def _append_simple_shard( min_local_shape=1, ) ) - sharding_config.tp_transforms.extend(tp_shards) + 
sharding_config.weight_sharding_transforms.extend(tp_shards) class ShardingTransformConfig(TransformConfig): @@ -166,7 +174,7 @@ def _apply( shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: local_rank, world_size = shared_config.local_rank, shared_config.world_size - # world_size = 2 + world_size = 2 if world_size < 2: ad_logger.info("Skipping sharding for single device") @@ -237,6 +245,235 @@ def _apply( return gm, info +def _process_ssm_sharding( + gm: GraphModule, + entry_node: Node, + rank: int, + world_size: int, + min_local_shape: int = 1, +) -> Tuple[List[WeightShardingInfo], List[ParameterUpdateInfo]]: + """ + Process the SSM sharding from the candidate nodes and update the view and split nodes accordingly. + """ + # Find next linear node to define subgraph boundary + try: + next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False) + except RuntimeError: + ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") + return False + + weight_sharding_transforms = [] + parameter_update_transforms = [] + + # Get subgraph between entry_node and next linear node + subgraph_nodes = subgraph([entry_node], [next_lin_node]) + + ############################################################## + ########## infer split sizes for in_proj and conv1d ########## + ############################################################## + # in_proj and conv1d are fused, followed up by split nodes. Infer split sizes: + split_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) + ] + if len(split_nodes) != 2: + ad_logger.warning( + f"Subgraph does not contain exactly two split nodes. " + f"Skipping Mamba sharding. split_nodes={split_nodes}" + ) + return False + split_sizes_1 = split_nodes[0].args[1] + split_sizes_2 = split_nodes[1].args[1] + if split_sizes_1[1] != sum(split_sizes_2): + ad_logger.warning( + f"Split nodes have different sizes. " + f"Skipping Mamba sharding. split_sizes_1={split_sizes_1}, split_sizes_2={split_sizes_2}" + ) + return False + fused_weight_dims = { + "in_proj": split_sizes_1[0:1] + split_sizes_2 + split_sizes_1[2:], + "conv1d": split_sizes_2, + } + + ############################################################## + ############# update conv1d num output channels ############## + ############################################################## + conv1d_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) + ] + assert len(conv1d_nodes) == 1, "Expecting exactly one conv1d node" + conv1d_node = conv1d_nodes[0] + # conv1d_node last argument is the number of output channels. 
+ # This one is also sharded, so we need to update this parameter + conv_args = list(conv1d_node.args) + conv_args[-1] = conv1d_node.args[-1] // world_size + parameter_update_transforms.append( + ParameterUpdateInfo(target_node=conv1d_node.name, args=tuple(conv_args)) + ) + + ############################################################## + ####### shard the entry_node (the first linear layer) ######## + ############################################################## + weight_sharding_transforms.append( + gm=gm, + node=entry_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims["in_proj"], + ) + + ############################################################## + ############## shard the remaining weights ################### + ############################################################## + # Get all weight nodes in the subgraph except for out_proj + weight_nodes = [n for n in subgraph_nodes if is_op(n, [torch.ops.aten.get_attr])] + for weight_node in weight_nodes: + weight_key = weight_node.target + # Get the weight parameter + try: + weight_param = gm.get_parameter(weight_key) + except AttributeError: + ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") + continue + + # Get fused dims for this weight if specified + fused_dims = None + for k, v in fused_weight_dims.items(): + if k in weight_key: + fused_dims = v + break + + # Shard the weight tensor (also updates the parameter in the module) + weight_sharding_transforms.append( + gm=gm, + weight_tensor=weight_param, + param_key=weight_key, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_dims, + ) + + ############################################################## + ############## update the view and split nodes ############### + ############################################################## + nodes_to_validate = [ + n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) + ] + for view_node in nodes_to_validate: + if len(view_node.args) < 2: + continue + view_shape = list(view_node.args[1]) + if not isinstance(view_shape, list): + continue + if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: + args = list(view_node.args) + view_shape[2] = view_shape[2] // world_size + args[1] = tuple(view_shape) + parameter_update_transforms.append( + ParameterUpdateInfo(target_node=view_node.name, args=tuple(args)) + ) + ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") + + split_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) + ] + for split_node in split_nodes: + if len(split_node.args) < 2: + continue + split_sizes = list(split_node.args[1]) + if not isinstance(split_sizes, list): + continue + split_sizes[1] = split_sizes[1] // world_size + split_node.args = tuple(split_sizes) + parameter_update_transforms.append( + ParameterUpdateInfo(target_node=split_node.name, args=tuple(split_node.args)) + ) + ad_logger.debug(f"\nUpdated split node {split_node} arguments to {split_node.args}") + + return weight_sharding_transforms, parameter_update_transforms + + +def _process_column_sharding( + gm: GraphModule, + linear_nodes: List[Node], + rank: int, + world_size: int, + min_local_shape: int = 1, + fused_weight: bool = False, +) -> Tuple[List[WeightShardingInfo], List[ParameterUpdateInfo]]: + """ 
+ Parse the column sharding from the candidate nodes and update the view and split nodes accordingly. + """ + weight_sharding_transforms = [] + parameter_update_transforms = [] + for linear_node in linear_nodes: + weight_sharding_transforms.append( + WeightShardingInfo.from_node( + linear_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, # for column sharding, no dist op is performed + min_local_shape=min_local_shape, + ) + ) + + # get the subgraph of this module. Subgraph boundary is the next linear node. + next_lin_node, depth = bfs(linear_nodes[0], is_linear_op, include_root=False) + subgraph_nodes = subgraph( + [linear_nodes], + [next_lin_node], + ) + + nodes_to_validate = [ + n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) + ] + for view_node in nodes_to_validate: + if len(view_node.args) < 2: + continue + view_shape = list(view_node.args[1]) + if not isinstance(view_shape, list): + continue + if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1: + args = list(view_node.args) + view_shape[2] = view_shape[2] // world_size + args[1] = tuple(view_shape) + parameter_update_transforms.append( + ParameterUpdateInfo(target_node=view_node.name, args=tuple(args)) + ) + ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") + + # if fused_weight_dims is provided, we need to update all split sizes + if fused_weight: + assert len(linear_nodes) == 1, "Fused weight should be only one linear node" + node = linear_nodes[0] + assert world_size is not None, "World size is required to update the split node params" + assert len(node.users) == 1, "Fused linear node should have only one user: a split node" + user = list(node.users)[0] + if is_op(user, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]): + orig_sizes = user.args[1] + new_sizes = [orig_sizes[i] // world_size for i in range(len(orig_sizes))] + args = list(user.args) + args[1] = new_sizes + parameter_update_transforms.append( + ParameterUpdateInfo(target_node=user.name, args=tuple(args)) + ) + ad_logger.debug( + f"\nInserted parameter update transformation for split node {user} arguments to {user.args}" + ) + return weight_sharding_transforms, parameter_update_transforms + + def detect_sharding_from_factory_config( gm: GraphModule, sharding_config: ShardingConfig, @@ -310,8 +547,8 @@ def detect_sharding_from_factory_config( # we have a match. 
Get the config for this layer config = tp_plan[key] if config == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -322,8 +559,8 @@ def detect_sharding_from_factory_config( ) num_row_col_shards += 1 elif config == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.ROW, rank=rank, @@ -334,8 +571,8 @@ def detect_sharding_from_factory_config( ) num_row_col_shards += 1 elif config == "mamba": - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -355,8 +592,8 @@ def detect_sharding_from_factory_config( if "shared" in module_name: col_row_action = config.replace("local_", "") if col_row_action == "colwise": - sharding_config.tp_transforms.append( - TPShardingInfo( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo( target_node=lin_node.name, split_dim=SplitDimension.COLUMN, rank=rank, @@ -366,8 +603,8 @@ def detect_sharding_from_factory_config( ) ) elif col_row_action == "rowwise": - sharding_config.tp_transforms.append( - TPShardingInfo( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo( target_node=lin_node.name, split_dim=SplitDimension.ROW, rank=rank, @@ -385,8 +622,8 @@ def detect_sharding_from_factory_config( elif "gather" in config: # Simple shard (row + all_gather) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -400,8 +637,8 @@ def detect_sharding_from_factory_config( ad_logger.warning( f"Unsupported sharding action {config}. Fallback to simple shard" ) - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, rank=rank, @@ -418,7 +655,7 @@ def detect_sharding_from_factory_config( f"row-col pattern: {num_row_col_shards})" ) - num_matches = len(sharding_config.tp_transforms) + num_matches = len(sharding_config.weight_sharding_transforms) if sharding_config.support_partial_config: ad_logger.info( @@ -515,6 +752,12 @@ def detect_column_row_shard( operator.getitem, } + # SSM nodes (mamba layers) + ssm_nodes = { + torch.ops.auto_deploy.torch_ssm_transform, + torch.ops.auto_deploy.torch_causal_conv1d, + } + # let's look at linear nodes we can identify between pairs of boundary nodes # There is three potential cases we can handle: # 1. 
No linear nodes: @@ -542,6 +785,8 @@ def detect_column_row_shard( attention_nodes.add(current_node) elif is_op(current_node, shardable_nodes_with_attention): attention_related_nodes.add(current_node) + elif is_op(current_node, ssm_nodes): + ssm_nodes.add(current_node) elif not is_op(current_node, pointwise_ops): unaccounted_nodes.add(current_node) current_node = current_node.next @@ -566,6 +811,29 @@ def detect_column_row_shard( num_simple_shards += 1 continue + if len(ssm_nodes) == 2: + # we expect one input linear node in_proj and one output linear node out_proj + in_proj_node = nodes_linear.values()[0] + out_proj_node = nodes_linear.values()[1] + if len(in_proj_node) == 1 and len(out_proj_node) == 1: + weight_sharding_transforms, parameter_update_transforms = _process_ssm_sharding( + gm, in_proj_node[0], rank, world_size + ) + sharding_config.weight_sharding_transforms.extend(weight_sharding_transforms) + sharding_config.parameter_update_transforms.extend(parameter_update_transforms) + # shard single row node + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( + out_proj_node[0], + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + ) + ) + num_row_col_shards += 1 + continue + # let's look at the unnacounted nodes. They are okay as long as they fall before the # first linear node or after the last linear node, i.e., outside the sharded region lin_nodes_flat: Set[Node] = {n for group in nodes_linear.values() for n in group} @@ -614,22 +882,39 @@ def detect_column_row_shard( min_local_shape = attention_nodes.pop().meta["val"].shape[-1] else: min_local_shape = 1 - for i, group in enumerate(nodes_linear.values()): - for n in group: - if i > 0: - dist_op = "all_reduce" - else: - dist_op = None - sharding_config.tp_transforms.append( - TPShardingInfo.from_node( - n, - split_dim=i, - rank=rank, - world_size=world_size, - dist_op=dist_op, - min_local_shape=min_local_shape, - ) - ) + + # We are inserting column-row shard for each group of linear nodes + # This may require parameter update of nodes whose args depend on (sharded) dimensions, + # such as view or split nodes. 
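        # Illustrative example (the numbers are hypothetical, not from a specific model):
        # after column-sharding a fused QKV projection with world_size=4, a downstream
        # view such as view(x, (b, s, 32, 128)) must become view(x, (b, s, 8, 128)) on
        # each rank, and split_with_sizes sizes over the sharded output dim must likewise
        # be divided by world_size; the ParameterUpdateInfo transforms collected below
        # carry these rewritten args.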
+ nodes_to_column_shard = nodes_linear.values()[0] + nodes_to_row_shard = nodes_linear.values()[1] + if len(nodes_to_row_shard) != 1: + ad_logger.warning( + "Expecting only one linear node for row sharding, but got %s", + len(nodes_to_row_shard), + ) + num_simple_shards += 1 + _append_simple_shard(nodes_to_row_shard, rank, world_size, sharding_config) + continue + + # column-row sharding + weight_sharding_transforms, parameter_update_transforms = _process_column_sharding( + gm, nodes_to_column_shard, rank, world_size, min_local_shape + ) + sharding_config.weight_sharding_transforms.extend(weight_sharding_transforms) + sharding_config.parameter_update_transforms.extend(parameter_update_transforms) + + # shard single row node + sharding_config.weight_sharding_transforms.append( + WeightShardingInfo.from_node( + nodes_to_row_shard[0], + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + ) + ) + num_row_col_shards += 1 ad_logger.info( diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 40b45592370..5706055cf6b 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from enum import Enum, IntEnum from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -339,7 +339,7 @@ def _insert_sharded_mamba( ############################################################## ####### shard the entry_node (the first linear layer) ######## ############################################################## - _insert_sharded_matmul( + _shard_parameter_node( gm=gm, node=entry_node, dim=dim, @@ -405,7 +405,7 @@ def _insert_sharded_mamba( ) -def _insert_sharded_matmul( +def _shard_parameter_node( gm: GraphModule, node: Node, dim: int, @@ -418,7 +418,7 @@ def _insert_sharded_matmul( Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] ] = None, ) -> None: - """Replace the matmul node with a new matmul node that accepts sharded weights. + """Replace the node with parametrized weight tensor with a new node that accepts sharded weights. The state_dict is also updated to contain the sharded weights. """ @@ -487,10 +487,10 @@ def _insert_sharded_matmul( world_size=world_size, ) - # column shard with no gather: the output is sharded - if not add_dist: - _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) - return + # # # column shard with no gather: the output is sharded + # if not add_dist: + # _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) + # return # figure out the right dist op dist_lookup = { @@ -506,6 +506,14 @@ def _insert_sharded_matmul( dist_node.replace_input_with(dist_node, node) +def _update_node_args(node: Node, args: tuple) -> None: + """Update the node's arguments with the new sharded arguments.""" + node.args = args + ad_logger.debug( + f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." 
+ ) + + class SplitDimension(IntEnum): """Enum for tensor split dimensions in sharding.""" @@ -560,20 +568,17 @@ class LayerType(Enum): MOE = "moe" -class TPShardingInfo(ShardingTransformInfo): +class WeightShardingInfo(ShardingTransformInfo): """Configuration for TP sharding transformations.""" split_dim: SplitDimension dist_op: Optional[Literal["all_reduce", "all_gather"]] = None min_local_shape: int = 1 - layer_type: LayerType = LayerType.MLP # used for TP sharding of fused weights - # For MLP/Attention: list of dimensions for fused weights (e.g., [dim1, dim2] for QKV) - # For Mamba: dict mapping weight keys to their fused dimensions - fused_weight_dims: Optional[Union[list, Dict[str, list]]] = None + fused_weight_dims: Optional[list] = None @classmethod - def from_node(cls, node: Node, **kwargs) -> "TPShardingInfo": + def from_node(cls, node: Node, **kwargs) -> "WeightShardingInfo": """ Create the correct TPShardingInfo subclass (FP8/FP4/base) based on `node`. """ @@ -600,30 +605,47 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply TP sharding transformation to the graph module.""" - if self.layer_type == LayerType.MAMBA: - _insert_sharded_mamba( - gm=gm, - entry_node=node, - dim=self.split_dim.value, - rank=self.rank, - world_size=self.world_size, - add_dist=self.dist_op is not None, - min_local_shape=self.min_local_shape, - fused_weight_dims=self.fused_weight_dims - if isinstance(self.fused_weight_dims, dict) - else None, - ) - else: - _insert_sharded_matmul( - gm=gm, - node=node, - dim=self.split_dim.value, - rank=self.rank, - world_size=self.world_size, - add_dist=self.dist_op is not None, - min_local_shape=self.min_local_shape, - fused_weight_dims=self.fused_weight_dims, - ) + # if self.layer_type == LayerType.MAMBA: + # _insert_sharded_mamba( + # gm=gm, + # entry_node=node, + # dim=self.split_dim.value, + # rank=self.rank, + # world_size=self.world_size, + # add_dist=self.dist_op is not None, + # min_local_shape=self.min_local_shape, + # fused_weight_dims=self.fused_weight_dims + # if isinstance(self.fused_weight_dims, dict) + # else None, + # ) + # else: + _shard_parameter_node( + gm=gm, + node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + fused_weight_dims=self.fused_weight_dims, + ) + + +class ParameterUpdateInfo(ShardingTransformInfo): + """Configuration for node args sharding transformations.""" + + target_node: str + rank: int + world_size: int + args: tuple + + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + """Validate the transformation configuration.""" + return len(node.args) == len(self.args) + + def apply(self, gm: GraphModule, node: Node) -> None: + """Apply the transformation to the graph module.""" + _update_node_args(node, self.args) class QuantizationShardingMixin(ABC): @@ -687,7 +709,7 @@ def quantization_cb( ) -class FP8TPShardingInfo(QuantizationShardingMixin, TPShardingInfo): +class FP8TPShardingInfo(QuantizationShardingMixin, WeightShardingInfo): """Tensor-parallel sharding for FP8-quantized linears.""" def scale_names(self) -> List[str]: @@ -722,7 +744,7 @@ def shard_load_hook( return def apply(self, gm: GraphModule, node: Node) -> None: - _insert_sharded_matmul( + _shard_parameter_node( gm=gm, node=node, dim=self.split_dim.value, @@ -747,7 +769,7 @@ def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, ) 
-class FP4TPShardingInfo(QuantizationShardingMixin, TPShardingInfo): +class FP4TPShardingInfo(QuantizationShardingMixin, WeightShardingInfo): """Tensor-parallel sharding for FP4-quantized linears.""" def scale_names(self) -> List[str]: @@ -790,7 +812,7 @@ def shard_load_hook( ) def apply(self, gm: GraphModule, node: Node) -> None: - _insert_sharded_matmul( + _shard_parameter_node( gm=gm, node=node, dim=self.split_dim.value, @@ -815,7 +837,7 @@ def _resolve_tp_cls_from_node(node: Node): return cls except Exception: pass - return TPShardingInfo + return WeightShardingInfo class BMMShardingInfo(ShardingTransformInfo): @@ -1177,7 +1199,8 @@ class ShardingConfig(BaseModel): use_sharding_from_factory: bool = False support_partial_config: bool = False sharding_dims: List[str] = Field(default_factory=list) - tp_transforms: List[TPShardingInfo] = Field(default_factory=list) + weight_sharding_transforms: List[WeightShardingInfo] = Field(default_factory=list) + parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 76d48669d61..58855fb0318 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -15,7 +15,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ( SplitDimension, - TPShardingInfo, + WeightShardingInfo, ) from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op @@ -272,7 +272,7 @@ def _run_pattern_detection_job( dim = SplitDimension.COLUMN dist_op = None expected_transformations.append( - TPShardingInfo( + WeightShardingInfo( target_node=node.name, split_dim=dim, rank=rank, @@ -293,7 +293,7 @@ def _run_pattern_detection_job( dim = SplitDimension.ROW dist_op = "all_reduce" expected_transformations.append( - TPShardingInfo( + WeightShardingInfo( target_node=node.name, split_dim=dim, rank=rank, @@ -307,7 +307,7 @@ def _run_pattern_detection_job( for node in gm.graph.nodes: if is_linear_op(node): expected_transformations.append( - TPShardingInfo( + WeightShardingInfo( target_node=node.name, split_dim=SplitDimension.COLUMN, # Simple shard uses dim=0 rank=rank, @@ -351,7 +351,7 @@ def _run_pattern_detection_job( optimizer.shared_config.local_rank = rank optimizer.shared_config.world_size = world_size _ = optimizer(None, gm) - detected_transformations = optimizer.shared_config.sharding_config.tp_transforms + detected_transformations = optimizer.shared_config.sharding_config.weight_sharding_transforms print(f"detected_transformations: {detected_transformations}") print(f"expected_transformations: {expected_transformations}") From ea1b623639f2950df0b5907e0fb56b8190dc8876 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 26 Oct 2025 12:36:49 -0700 Subject: [PATCH 10/12] WiP Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 2 +- 
.../_torch/auto_deploy/transform/interface.py | 9 + .../auto_deploy/transform/library/fusion.py | 6 +- .../transform/library/quantization.py | 4 +- .../auto_deploy/transform/library/sharding.py | 295 ++++++++++-------- .../_torch/auto_deploy/utils/node_utils.py | 61 ++-- .../auto_deploy/utils/quantization_utils.py | 4 +- .../auto_deploy/utils/sharding_utils.py | 44 ++- 8 files changed, 253 insertions(+), 172 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 13f1cf0703f..a1c36a205c8 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -77,7 +77,7 @@ transforms: simple_shard_only: false use_sharding_from_factory: false support_partial_config: false - sharding_dims: ['tp', 'ep', 'bmm'] + sharding_dims: ['ep', 'bmm', 'ssm', 'tp'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. sharding_transform_executor: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 0bd28a1d78d..81140fccb1b 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -173,6 +173,15 @@ def __and__(self, other: "TransformInfo") -> "TransformInfo": has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes, ) + # implement + addition operator for TransformInfo + def __add__(self, other: "TransformInfo") -> "TransformInfo": + return TransformInfo( + skipped=self.skipped and other.skipped, + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean and other.is_clean, + has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes, + ) + TransformHistory = Dict[str, TransformInfo] diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index 477cde8e02d..e04f3212722 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -13,7 +13,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger -from ...utils.node_utils import extract_param_names_from_lin_node, is_linear_op, is_op +from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node y2 = y[:, out1:out1+out2] """ # some info we need - keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes] + keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes] params_unfused = [gm.get_parameter(k) for k in keys_unfused] sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" @@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple def _insert_fused_quant_gemm( self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node] ): - keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes] + keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes] params_unfused = [gm.get_parameter(k) for k in keys_unfused] sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" diff 
--git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 94137e9a0b1..9f53a9bd637 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -14,7 +14,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import ( - extract_param_names_from_lin_node, + extract_param_names_from_node, get_quantization_params_from_linear_node, is_bmm_op, is_linear_op, @@ -136,7 +136,7 @@ def _insert_quantized_linear( The state_dict is also updated to contain the sharded weights. """ - param_name, _ = extract_param_names_from_lin_node(node) + param_name, _ = extract_param_names_from_node(node) original_weight = gm.get_parameter(param_name) new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False) modname, _, attrname = param_name.rpartition(".") diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 631925d4a80..7e87469f9b3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -107,28 +107,30 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: return gm, info -def _append_simple_shard( +def _process_simple_shard( nodes_linear: Dict[Node, List[Node]], rank: int, world_size: int, sharding_config: ShardingConfig, -) -> None: +) -> int: # for every linear node: # --> row_split (dim 0 of weight) + all_gather (dim -1 of output) - tp_shards: List[WeightShardingInfo] = [] + num_simple_shards = 0 for node_group in nodes_linear.values(): for n in node_group: - tp_shards.append( - WeightShardingInfo.from_node( - n, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op="all_gather", - min_local_shape=1, + num_simple_shards += int( + sharding_config.add( + WeightShardingInfo.from_node( + n, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op="all_gather", + min_local_shape=1, + ) ) ) - sharding_config.weight_sharding_transforms.extend(tp_shards) + return num_simple_shards class ShardingTransformConfig(TransformConfig): @@ -217,6 +219,12 @@ def _apply( tp_info = TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) + if "ssm" in sharding_config.sharding_dims: + ssm_info = detect_ssm_shard(gm, sharding_config) + else: + ssm_info = TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) # run EP sharding across ranks if "ep" in sharding_config.sharding_dims: @@ -234,39 +242,30 @@ def _apply( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - info = TransformInfo( - skipped=tp_info.skipped and ep_info.skipped and dp_bmm_info.skipped, - num_matches=tp_info.num_matches + ep_info.num_matches + dp_bmm_info.num_matches, - is_clean=tp_info.is_clean and ep_info.is_clean and dp_bmm_info.is_clean, - has_valid_shapes=tp_info.has_valid_shapes - and ep_info.has_valid_shapes - and dp_bmm_info.has_valid_shapes, - ) + info = tp_info + ssm_info + ep_info + dp_bmm_info return gm, info def _process_ssm_sharding( gm: GraphModule, entry_node: Node, + sharding_config: ShardingConfig, rank: int, world_size: int, min_local_shape: int = 1, -) -> Tuple[List[WeightShardingInfo], List[ParameterUpdateInfo]]: +) -> int: """ Process the SSM 
sharding from the candidate nodes and update the view and split nodes accordingly. """ # Find next linear node to define subgraph boundary try: - next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False) + out_proj_node, depth = bfs(entry_node, is_linear_op, include_root=False) except RuntimeError: ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") - return False - - weight_sharding_transforms = [] - parameter_update_transforms = [] + return 0 # Get subgraph between entry_node and next linear node - subgraph_nodes = subgraph([entry_node], [next_lin_node]) + subgraph_nodes = subgraph([entry_node], [out_proj_node]) ############################################################## ########## infer split sizes for in_proj and conv1d ########## @@ -282,20 +281,44 @@ def _process_ssm_sharding( f"Subgraph does not contain exactly two split nodes. " f"Skipping Mamba sharding. split_nodes={split_nodes}" ) - return False - split_sizes_1 = split_nodes[0].args[1] - split_sizes_2 = split_nodes[1].args[1] - if split_sizes_1[1] != sum(split_sizes_2): + return 0 + split_sizes_0 = split_nodes[0].args[1] + split_sizes_1 = split_nodes[1].args[1] + if split_sizes_0[1] != sum(split_sizes_1): ad_logger.warning( f"Split nodes have different sizes. " - f"Skipping Mamba sharding. split_sizes_1={split_sizes_1}, split_sizes_2={split_sizes_2}" + f"Skipping Mamba sharding. split_sizes_1={split_sizes_0}, split_sizes_2={split_sizes_1}" ) - return False + return 0 fused_weight_dims = { - "in_proj": split_sizes_1[0:1] + split_sizes_2 + split_sizes_1[2:], - "conv1d": split_sizes_2, + "in_proj": split_sizes_0[0:1] + split_sizes_1 + split_sizes_0[2:], + "conv1d": split_sizes_1, } + ############################################################## + ############## update split nodes ############################ + ############################################################## + split_args_0 = list(split_nodes[0].args) + split_args_0[1] = [s // world_size for s in split_args_0[1]] + split_args_1 = list(split_nodes[1].args) + split_args_1[1] = [s // world_size for s in split_args_1[1]] + sharding_config.add( + ParameterUpdateInfo( + rank=rank, + world_size=world_size, + target_node=split_nodes[0].name, + args=tuple(split_args_0), + ) + ) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, + world_size=world_size, + target_node=split_nodes[1].name, + args=tuple(split_args_1), + ) + ) + ############################################################## ############# update conv1d num output channels ############## ############################################################## @@ -310,34 +333,38 @@ def _process_ssm_sharding( # This one is also sharded, so we need to update this parameter conv_args = list(conv1d_node.args) conv_args[-1] = conv1d_node.args[-1] // world_size - parameter_update_transforms.append( - ParameterUpdateInfo(target_node=conv1d_node.name, args=tuple(conv_args)) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=conv1d_node.name, args=tuple(conv_args) + ) ) ############################################################## ####### shard the entry_node (the first linear layer) ######## ############################################################## - weight_sharding_transforms.append( - gm=gm, - node=entry_node, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - dist_op=None, - min_local_shape=min_local_shape, - fused_weight_dims=fused_weight_dims["in_proj"], + sharding_config.add( + 
WeightShardingInfo.from_node( + entry_node, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims["in_proj"], + ) ) ############################################################## ############## shard the remaining weights ################### ############################################################## - # Get all weight nodes in the subgraph except for out_proj - weight_nodes = [n for n in subgraph_nodes if is_op(n, [torch.ops.aten.get_attr])] + # Get all weight nodes in the subgraph except for out_proj (it has to be row-sharded) + # weight_nodes = [n for n in subgraph_nodes if is_op(n, [torch.ops.aten.get_attr])] + weight_nodes = [n for n in subgraph_nodes if n.op == "get_attr" and "out_proj" not in n.target] for weight_node in weight_nodes: weight_key = weight_node.target # Get the weight parameter try: - weight_param = gm.get_parameter(weight_key) + gm.get_parameter(weight_key) except AttributeError: ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") continue @@ -350,19 +377,20 @@ def _process_ssm_sharding( break # Shard the weight tensor (also updates the parameter in the module) - weight_sharding_transforms.append( - gm=gm, - weight_tensor=weight_param, - param_key=weight_key, - split_dim=SplitDimension.COLUMN, - rank=rank, - world_size=world_size, - min_local_shape=min_local_shape, - fused_weight_dims=fused_dims, + sharding_config.add( + WeightShardingInfo.from_node( + list(weight_node.users)[0], + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + fused_weight_dims=fused_dims, + ) ) ############################################################## - ############## update the view and split nodes ############### + ############## update the view and reshape nodes ############# ############################################################## nodes_to_validate = [ n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) @@ -377,47 +405,42 @@ def _process_ssm_sharding( args = list(view_node.args) view_shape[2] = view_shape[2] // world_size args[1] = tuple(view_shape) - parameter_update_transforms.append( - ParameterUpdateInfo(target_node=view_node.name, args=tuple(args)) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=view_node.name, args=tuple(args) + ) ) ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") - split_nodes = [ - n - for n in subgraph_nodes - if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) - ] - for split_node in split_nodes: - if len(split_node.args) < 2: - continue - split_sizes = list(split_node.args[1]) - if not isinstance(split_sizes, list): - continue - split_sizes[1] = split_sizes[1] // world_size - split_node.args = tuple(split_sizes) - parameter_update_transforms.append( - ParameterUpdateInfo(target_node=split_node.name, args=tuple(split_node.args)) + ############################################################## + ############## shard the out_proj node ####################### + ############################################################## + sharding_config.add( + WeightShardingInfo.from_node( + out_proj_node, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", ) - ad_logger.debug(f"\nUpdated split node {split_node} arguments to {split_node.args}") - - return weight_sharding_transforms, 
parameter_update_transforms + ) + return 1 def _process_column_sharding( gm: GraphModule, linear_nodes: List[Node], + sharding_config: ShardingConfig, rank: int, world_size: int, min_local_shape: int = 1, fused_weight: bool = False, -) -> Tuple[List[WeightShardingInfo], List[ParameterUpdateInfo]]: +) -> None: """ Parse the column sharding from the candidate nodes and update the view and split nodes accordingly. """ - weight_sharding_transforms = [] - parameter_update_transforms = [] for linear_node in linear_nodes: - weight_sharding_transforms.append( + sharding_config.add( WeightShardingInfo.from_node( linear_node, split_dim=SplitDimension.COLUMN, @@ -448,8 +471,10 @@ def _process_column_sharding( args = list(view_node.args) view_shape[2] = view_shape[2] // world_size args[1] = tuple(view_shape) - parameter_update_transforms.append( - ParameterUpdateInfo(target_node=view_node.name, args=tuple(args)) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=view_node.name, args=tuple(args) + ) ) ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") @@ -465,13 +490,14 @@ def _process_column_sharding( new_sizes = [orig_sizes[i] // world_size for i in range(len(orig_sizes))] args = list(user.args) args[1] = new_sizes - parameter_update_transforms.append( - ParameterUpdateInfo(target_node=user.name, args=tuple(args)) + sharding_config.add( + ParameterUpdateInfo( + rank=rank, world_size=world_size, target_node=user.name, args=tuple(args) + ) ) ad_logger.debug( f"\nInserted parameter update transformation for split node {user} arguments to {user.args}" ) - return weight_sharding_transforms, parameter_update_transforms def detect_sharding_from_factory_config( @@ -688,6 +714,41 @@ def detect_sharding_from_factory_config( ) +def detect_ssm_shard( + gm: GraphModule, + sharding_config: ShardingConfig, +) -> TransformInfo: + """A transformation to apply sharding to the model following SSM parallelism. + TODO: This is a TEMPORARY place for this logic due to the incompatibility between the + identify_regions_between_residuals() and subgraph() methods to detect layers. + The goal is to have a unified single pass over the graph to detect layers and apply + appropriate sharding transformations. + """ + rank, world_size = sharding_config.rank, sharding_config.world_size + if world_size < 2: + ad_logger.info("Skipping TP sharding for single device") + return TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) + ad_logger.info("Running SSM sharding detection") + + # find all ssm nodes in the graph + ssm_nodes = filtered_nodes(gm.graph.nodes, ops=torch.ops.auto_deploy.torch_ssm) + num_ssm_shards = 0 + for ssm_node in ssm_nodes: + # We assume that one ssm node defines a subgraph corresponding + # to a single Mamba layer. + # Find defining previous (in_proj) and next (out_proj) linear nodes. 
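For the boundary-detection step above, a self-contained torch.fx sketch of the same traversal idea (is_linear and nearest_linear are stand-ins, not the project's bfs/is_linear_op): walk backwards through args to the producing linear and forwards through users to the consuming one.

import torch
from torch import fx, nn

class TinyMamba(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(16, 32)
        self.out_proj = nn.Linear(32, 16)

    def forward(self, x):
        h = torch.sigmoid(self.in_proj(x))   # stands in for the SSM body
        return self.out_proj(h)

gm = fx.symbolic_trace(TinyMamba())

def is_linear(gm, n):
    return n.op == "call_module" and isinstance(gm.get_submodule(n.target), nn.Linear)

def nearest_linear(gm, start, direction):
    """BFS through args (backwards) or users (forwards) until a Linear is hit."""
    frontier, seen = [start], {start}
    while frontier:
        nxt = []
        for node in frontier:
            neighbors = node.all_input_nodes if direction == "args" else list(node.users)
            for nb in neighbors:
                if nb in seen:
                    continue
                if is_linear(gm, nb):
                    return nb
                seen.add(nb)
                nxt.append(nb)
        frontier = nxt
    raise RuntimeError("no linear node found")

anchor = next(n for n in gm.graph.nodes if n.target == torch.sigmoid)
print(nearest_linear(gm, anchor, "args").target)   # in_proj
print(nearest_linear(gm, anchor, "users").target)  # out_proj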
+ in_proj_node, _ = bfs(ssm_node, is_linear_op, attr_next="args", include_root=False) + + num_ssm_shards += int( + _process_ssm_sharding(gm, in_proj_node, sharding_config, rank, world_size) + ) + + ad_logger.info(f"Found {num_ssm_shards} SSM shards") + return TransformInfo( + skipped=False, num_matches=num_ssm_shards, is_clean=False, has_valid_shapes=False + ) + + def detect_column_row_shard( gm: GraphModule, sharding_config: ShardingConfig, @@ -752,12 +813,6 @@ def detect_column_row_shard( operator.getitem, } - # SSM nodes (mamba layers) - ssm_nodes = { - torch.ops.auto_deploy.torch_ssm_transform, - torch.ops.auto_deploy.torch_causal_conv1d, - } - # let's look at linear nodes we can identify between pairs of boundary nodes # There is three potential cases we can handle: # 1. No linear nodes: @@ -785,8 +840,6 @@ def detect_column_row_shard( attention_nodes.add(current_node) elif is_op(current_node, shardable_nodes_with_attention): attention_related_nodes.add(current_node) - elif is_op(current_node, ssm_nodes): - ssm_nodes.add(current_node) elif not is_op(current_node, pointwise_ops): unaccounted_nodes.add(current_node) current_node = current_node.next @@ -800,40 +853,19 @@ def detect_column_row_shard( if sharding_config.simple_shard_only: ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # simple shard when we have != 2 groups of linear nodes if len(nodes_linear) != 2: ad_logger.debug(f"Linear groups: {nodes_linear}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue - if len(ssm_nodes) == 2: - # we expect one input linear node in_proj and one output linear node out_proj - in_proj_node = nodes_linear.values()[0] - out_proj_node = nodes_linear.values()[1] - if len(in_proj_node) == 1 and len(out_proj_node) == 1: - weight_sharding_transforms, parameter_update_transforms = _process_ssm_sharding( - gm, in_proj_node[0], rank, world_size - ) - sharding_config.weight_sharding_transforms.extend(weight_sharding_transforms) - sharding_config.parameter_update_transforms.extend(parameter_update_transforms) - # shard single row node - sharding_config.weight_sharding_transforms.append( - WeightShardingInfo.from_node( - out_proj_node[0], - split_dim=SplitDimension.ROW, - rank=rank, - world_size=world_size, - dist_op="all_reduce", - ) - ) - num_row_col_shards += 1 - continue - # let's look at the unnacounted nodes. They are okay as long as they fall before the # first linear node or after the last linear node, i.e., outside the sharded region lin_nodes_flat: Set[Node] = {n for group in nodes_linear.values() for n in group} @@ -861,8 +893,9 @@ def detect_column_row_shard( # check if any unaccounted nodes are left. 
If so, do a simply shard if unaccounted_nodes or attention_related_nodes: ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # If we can account for all sharded nodes, we can do a two-way shard @@ -874,8 +907,9 @@ def detect_column_row_shard( # Column-row shard boundary region detection is probably wrong - there should be # only one attention operation. Fall back to simple shard. ad_logger.debug(f"More than one attention node: {unaccounted_nodes}") - _append_simple_shard(nodes_linear, rank, world_size, sharding_config) - num_simple_shards += 1 + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # Extract head dimension. We cannot shard below the head_dim size. # Assume that head_dim is the last (innermost) dimension of the tensor @@ -883,7 +917,7 @@ def detect_column_row_shard( else: min_local_shape = 1 - # We are inserting column-row shard for each group of linear nodes + # We are inserting column-row shard for each group of linear enodes # This may require parameter update of nodes whose args depend on (sharded) dimensions, # such as view or split nodes. nodes_to_column_shard = nodes_linear.values()[0] @@ -893,8 +927,9 @@ def detect_column_row_shard( "Expecting only one linear node for row sharding, but got %s", len(nodes_to_row_shard), ) - num_simple_shards += 1 - _append_simple_shard(nodes_to_row_shard, rank, world_size, sharding_config) + num_simple_shards += _process_simple_shard( + nodes_linear, rank, world_size, sharding_config + ) continue # column-row sharding diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 11dc80ac28d..877256bc608 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -106,10 +106,10 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node): return input_params, weight_params, output_params -def extract_weight_node(mm_node: Node) -> int: - """Extracts the weight node from the given linear or BMM node. 
We assume torch.bmm(activation, weight)""" +def extract_weight_node(node: Node) -> int: + """Extracts the weight node from the given parametrized node""" - def find_get_attr_node(node: Node) -> Node: + def find_get_attr_node(weight_node: Node) -> Node: """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op.""" # If node is a get_attr node return node # List of nodes allowed in between a get_attr node and the matmul node @@ -118,40 +118,47 @@ def find_get_attr_node(node: Node) -> Node: torch.ops.aten.view.default, } - if node.op == "get_attr": - return node + if weight_node.op == "get_attr": + return weight_node # If node is not in the list of allowable ops then return None - if node.target not in allowed_ops: + if weight_node.target not in allowed_ops: return None - for input_node in node.all_input_nodes: + for input_node in weight_node.all_input_nodes: result = find_get_attr_node(input_node) if result: return result return None - weight_node = mm_node.args[1] + if is_op(node, torch.ops.aten.bmm): + weight_node = node.args[1] + # for other parametrized nodes, we need to find the weight node + else: + weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] + # can be two weights (if bias weight is present) + assert len(weight_nodes) >= 1, "Expected exactly one weight node in the parametrized node" + weight_node = weight_nodes[0] # for modelopt quantized graph, there will be a quantize_op - _, weight_params, _ = get_quantization_params_from_linear_node(mm_node) + _, weight_params, _ = get_quantization_params_from_linear_node(node) weight_node = weight_params.input_node if weight_params else weight_node return find_get_attr_node(weight_node) -def num_users_of_weight_node(mm_node: Node) -> int: - """Returns the number of users of the weight node of the given matmul node.""" - weight_node = extract_weight_node(mm_node) +def num_users_of_weight_node(node: Node) -> int: + """Returns the number of users of the weight node of the given parametrized node.""" + weight_node = extract_weight_node(node) return len(weight_node.users) if weight_node is not None else 0 -def extract_param_names_from_lin_node(mm_node: Node) -> Tuple[str, Optional[str]]: - """Extracts the name of the parameter associated with the given matmul node. +def extract_param_names_from_node(node: Node) -> Tuple[str, Optional[str]]: + """Extracts the name of the parameter associated with the given parametrized node. Args: - mm_node: Matmul node in the graph. + node: node with weight parameters in the graph. """ - weight_node = extract_weight_node(mm_node) + weight_node = extract_weight_node(node) assert weight_node, "Cannot identify weight parameter of linear node." @@ -159,7 +166,14 @@ def extract_param_names_from_lin_node(mm_node: Node) -> Tuple[str, Optional[str] weight_name = weight_node.target # check for bias - bias_node = mm_node.args[2] if len(mm_node.args) > 2 else None + if is_op(node, torch.ops.aten.bmm): + bias_node = node.args[2] if len(node.args) > 2 else None + else: + weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] + if len(weight_nodes) > 1: + bias_node = weight_nodes[1] + else: + bias_node = None assert bias_node is None or bias_node.op == "get_attr" bias_name = bias_node.target if bias_node is not None else None @@ -386,12 +400,13 @@ def bfs( continue # Skip the boundary node. 
if target(cur_node) and (include_root or depth > 0): return cur_node, depth - for next_node in getattr(cur_node, attr_next): - if boundary is not None and next_node == boundary: - continue # Do not expand past the boundary. - if next_node not in visited: - visited.add(next_node) - queue_at_depth_next.append(next_node) + if hasattr(cur_node, attr_next): + for next_node in getattr(cur_node, attr_next): + if boundary is not None and next_node == boundary: + continue # Do not expand past the boundary. + if next_node not in visited: + visited.add(next_node) + queue_at_depth_next.append(next_node) if not queue_at_depth: queue_at_depth = queue_at_depth_next queue_at_depth_next = [] diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index 90e6b380338..aee98c37713 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -8,7 +8,7 @@ from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX from .logger import ad_logger from .node_utils import ( - extract_param_names_from_lin_node, + extract_param_names_from_node, get_quantization_params_from_linear_node, is_bmm_op, is_linear_op, @@ -117,7 +117,7 @@ def should_skip_quantization( else: if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)): return True - param_name, _ = extract_param_names_from_lin_node(node_or_name) + param_name, _ = extract_param_names_from_node(node_or_name) modname, _, _ = param_name.rpartition(".") return any(fnmatch(modname, pattern) for pattern in excluded_patterns) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 5706055cf6b..1c76d602ee8 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -17,7 +17,7 @@ from ..utils.logger import ad_logger from .node_utils import ( bfs, - extract_param_names_from_lin_node, + extract_param_names_from_node, is_linear_op, is_op, num_users_of_weight_node, @@ -171,10 +171,10 @@ def split_tensor( # Handle fused weights if fused_weight_dims is not None: # Split fused weights, apply TP sharding to each, then concatenate back - sharded_weight = torch.cat( - [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], - dim=dim, - ) + # sharded_weight = torch.cat( + # [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], + # dim=dim, + # ) # Create a function that applies the same logic for loading def split_fused_tensor( @@ -189,9 +189,10 @@ def split_fused_tensor( f_split = split_fused_tensor else: - sharded_weight = split_tensor(weight_tensor) + # sharded_weight = split_tensor(weight_tensor) f_split = split_tensor + sharded_weight = f_split(weight_tensor) sharded_shape = sharded_weight.shape # Register load hook @@ -330,7 +331,7 @@ def _insert_sharded_mamba( # Extract entry node's fused_weight_dims by matching weight name against patterns entry_fused_dims = None if fused_weight_dims: - entry_weight_key, _ = extract_param_names_from_lin_node(entry_node) + entry_weight_key, _ = extract_param_names_from_node(entry_node) for pattern, dims in fused_weight_dims.items(): if re.search(pattern, entry_weight_key): entry_fused_dims = dims @@ -432,7 +433,7 @@ def _shard_parameter_node( ) return # get weight and bias key - weight_key, bias_key = extract_param_names_from_lin_node(node) + weight_key, bias_key = 
extract_param_names_from_node(node) modname = weight_key.rpartition(".")[0] submod = gm.get_submodule(modname) @@ -488,9 +489,9 @@ def _shard_parameter_node( ) # # # column shard with no gather: the output is sharded - # if not add_dist: - # _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) - # return + if not add_dist: + _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) + return # figure out the right dist op dist_lookup = { @@ -1204,6 +1205,15 @@ class ShardingConfig(BaseModel): bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._transform_list_dict = { + WeightShardingInfo: self.weight_sharding_transforms, + BMMShardingInfo: self.bmm_transforms, + EPShardingInfo: self.ep_transforms, + ParameterUpdateInfo: self.parameter_update_transforms, + } + @model_validator(mode="after") def _validate_and_normalize(self): # Normalize empty dict to None for "no config" @@ -1214,6 +1224,18 @@ def _validate_and_normalize(self): self.validate_config() return self + def add(self, transform: ShardingTransformInfo) -> bool: + """Append a TP transform only if that node was + not sharded before. Do not overwrite existing transforms. + """ + # try to add to appropriate transformation list + transform_list = self._transform_list_dict[type(transform)] + for existing_transform in transform_list: + if existing_transform.target_node == transform.target_node: + return False + transform_list.append(transform) + return True + def validate_config(self) -> bool: if self.factory_source != ShardingConfigSource.HUGGINGFACE: ad_logger.warning( From d7511ef71d5a18fb7e0e2d67c70c948c9f06c8a9 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 26 Oct 2025 17:01:00 -0700 Subject: [PATCH 11/12] WiP Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 2 +- .../auto_deploy/transform/library/sharding.py | 99 ++++---- .../auto_deploy/utils/sharding_utils.py | 237 +++--------------- 3 files changed, 78 insertions(+), 260 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index a1c36a205c8..7a177b0751c 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -77,7 +77,7 @@ transforms: simple_shard_only: false use_sharding_from_factory: false support_partial_config: false - sharding_dims: ['ep', 'bmm', 'ssm', 'tp'] + sharding_dims: ['ssm'] #, 'tp'] #, 'ep', 'bmm'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. 
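A stripped-down sketch of the dedup rule that ShardingConfig.add() above encodes (illustrative dataclasses, not the project's pydantic models): a transform is appended only if no earlier pass already claimed the same target node, so overlapping detection passes ("ssm" first, then "tp") stay idempotent.

from dataclasses import dataclass, field
from typing import Dict, List, Type

@dataclass
class Transform:
    target_node: str

@dataclass
class Registry:
    per_type: Dict[Type[Transform], List[Transform]] = field(default_factory=dict)

    def add(self, t: Transform) -> bool:
        bucket = self.per_type.setdefault(type(t), [])
        if any(existing.target_node == t.target_node for existing in bucket):
            return False          # node already sharded by an earlier pass
        bucket.append(t)
        return True

reg = Registry()
assert reg.add(Transform("in_proj")) is True
assert reg.add(Transform("in_proj")) is False   # second pass is a no-op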
sharding_transform_executor: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 7e87469f9b3..16aa2c3c612 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -43,6 +43,8 @@ LayerType, ParameterUpdateInfo, ShardingConfig, + ShardingDim, + ShardingSource, ShardingTransformInfo, SplitDimension, WeightShardingInfo, @@ -104,6 +106,7 @@ def check_and_apply(transform: ShardingTransformInfo) -> bool: info = TransformInfo( skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False ) + # exit() return gm, info @@ -137,10 +140,13 @@ class ShardingTransformConfig(TransformConfig): """Configuration for sharding transformations.""" simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = Field(default=False) - support_partial_config: bool = Field(default=False) - # Which sharding families to run: any subset of {"tp", "ep", "bmm"} - sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"]) + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) + # Which sharding dimensions to run: any subset of {"tp", "ep", "bmm"} + sharding_dims: List[ShardingDim] = Field( + default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] + ) @TransformRegistry.register("detect_sharding") @@ -176,7 +182,7 @@ def _apply( shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: local_rank, world_size = shared_config.local_rank, shared_config.world_size - world_size = 2 + # world_size = 2 if world_size < 2: ad_logger.info("Skipping sharding for single device") @@ -195,54 +201,41 @@ def _apply( else ShardingConfigSource.UNKNOWN ) sharding_config.simple_shard_only = self.config.simple_shard_only - sharding_config.support_partial_config = self.config.support_partial_config sharding_config.sharding_dims = self.config.sharding_dims - - sharding_config.use_sharding_from_factory = self.config.use_sharding_from_factory + sharding_config.sharding_source = self.config.sharding_source sharding_config.validate_config() - # sharding_config.predefined_config = predefined_config - if ( - sharding_config.use_sharding_from_factory - and len(sharding_config.get_predefined_config()) > 0 - ): - ad_logger.info("Applying sharding from config") - factory_info = detect_sharding_from_factory_config(gm, sharding_config) - return gm, factory_info - - ad_logger.info(f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}") - # run TP sharding across ranks - if "tp" in sharding_config.sharding_dims: - tp_info = detect_column_row_shard(gm, sharding_config) - else: - tp_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) - if "ssm" in sharding_config.sharding_dims: - ssm_info = detect_ssm_shard(gm, sharding_config) - else: - ssm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) + for source in shared_config.sharding_config.sharding_source: + if source == ShardingSource.FACTORY: + if len(shared_config.sharding_config.get_predefined_config()) == 0: + ad_logger.warning( + "No factory config found. 
Skipping sharding from factory config" + ) + continue + ad_logger.info("Applying sharding from factory config") + info += detect_sharding_from_factory_config(gm, sharding_config) - # run EP sharding across ranks - if "ep" in sharding_config.sharding_dims: - ep_info = detect_ep_shard(gm, sharding_config) - else: - ep_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + elif source == ShardingSource.HEURISTIC: + ad_logger.info( + f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}" + ) + if ShardingDim.SSM in sharding_config.sharding_dims: + info += detect_ssm_shard(gm, sharding_config) - # run BMM sharding across ranks - if "bmm" in sharding_config.sharding_dims: - dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) - else: - dp_bmm_info = TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) + # run TP sharding across ranks + if ShardingDim.TP in sharding_config.sharding_dims: + info += detect_column_row_shard(gm, sharding_config) + + # run EP sharding across ranks + if ShardingDim.EP in sharding_config.sharding_dims: + info += detect_ep_shard(gm, sharding_config) + + # run BMM sharding across ranks + if ShardingDim.BMM in sharding_config.sharding_dims: + info += detect_dp_bmm_shard(gm, sharding_config) - info = tp_info + ssm_info + ep_info + dp_bmm_info return gm, info @@ -295,9 +288,9 @@ def _process_ssm_sharding( "conv1d": split_sizes_1, } - ############################################################## - ############## update split nodes ############################ - ############################################################## + # ############################################################## + # ############## update split nodes ############################ + # ############################################################## split_args_0 = list(split_nodes[0].args) split_args_0[1] = [s // world_size for s in split_args_0[1]] split_args_1 = list(split_nodes[1].args) @@ -319,9 +312,9 @@ def _process_ssm_sharding( ) ) - ############################################################## - ############# update conv1d num output channels ############## - ############################################################## + # ############################################################## + # ############# update conv1d num output channels ############## + # ############################################################## conv1d_nodes = [ n for n in subgraph_nodes @@ -358,7 +351,6 @@ def _process_ssm_sharding( ############## shard the remaining weights ################### ############################################################## # Get all weight nodes in the subgraph except for out_proj (it has to be row-sharded) - # weight_nodes = [n for n in subgraph_nodes if is_op(n, [torch.ops.aten.get_attr])] weight_nodes = [n for n in subgraph_nodes if n.op == "get_attr" and "out_proj" not in n.target] for weight_node in weight_nodes: weight_key = weight_node.target @@ -732,6 +724,7 @@ def detect_ssm_shard( # find all ssm nodes in the graph ssm_nodes = filtered_nodes(gm.graph.nodes, ops=torch.ops.auto_deploy.torch_ssm) + ssm_nodes = list(ssm_nodes)[1:2] num_ssm_shards = 0 for ssm_node in ssm_nodes: # We assume that one ssm node defines a subgraph corresponding diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 1c76d602ee8..11a315ed2a8 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ 
b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -2,7 +2,6 @@ import math import operator -import re from abc import ABC, abstractmethod from enum import Enum, IntEnum from functools import partial @@ -89,6 +88,8 @@ def _validate_sharded_shapes( for view_node in nodes_to_validate: if len(view_node.args) < 2: continue + if "sharded" in view_node.meta and view_node.meta["sharded"]: + continue view_shape = list(view_node.args[1]) if not isinstance(view_shape, list): continue @@ -97,6 +98,7 @@ def _validate_sharded_shapes( view_shape[2] = view_shape[2] // world_size args[1] = tuple(view_shape) view_node.args = tuple(args) + view_node.meta["sharded"] = True ad_logger.debug(f"\nUpdated view node {view_node} arguments to {view_node.args}") # if fused_weight_dims is provided, we need to update all split sizes @@ -127,7 +129,6 @@ def shard_weight_tensor( world_size: int, min_local_shape: int = 1, fused_weight_dims: Optional[list] = None, - custom_shard_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, requires_grad: bool = False, update_param: bool = True, ) -> Tuple[torch.Tensor, torch.Size]: @@ -170,18 +171,16 @@ def split_tensor( # Handle fused weights if fused_weight_dims is not None: - # Split fused weights, apply TP sharding to each, then concatenate back - # sharded_weight = torch.cat( - # [split_tensor(w) for w in torch.split(weight_tensor, fused_weight_dims, dim=dim)], - # dim=dim, - # ) - # Create a function that applies the same logic for loading def split_fused_tensor( t: torch.Tensor, fused_dims: list = fused_weight_dims, d: int = dim, ) -> torch.Tensor: + # dim_d = t.shape[d] + # num_parts = 1 + # part_size = dim_d // num_parts + # fused_dims = [part_size] * num_parts return torch.cat( [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)], dim=d, @@ -189,7 +188,6 @@ def split_fused_tensor( f_split = split_fused_tensor else: - # sharded_weight = split_tensor(weight_tensor) f_split = split_tensor sharded_weight = f_split(weight_tensor) @@ -226,186 +224,6 @@ def get_all_weights_in_subgraph( return weight_nodes -def _insert_sharded_mamba( - gm: GraphModule, - entry_node: Node, - dim: int, - rank: int, - world_size: int, - add_dist: bool = False, - min_local_shape: int = 1, - weights_to_shard: Optional[list[str]] = None, - weight_shard_dims: Optional[Dict[str, int]] = None, - fused_weight_dims: Optional[Dict[str, list]] = None, - quantization_cb: Optional[ - Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] - ] = None, -) -> bool: - """ - To shard Mamba layer, first column-shard the first linear layer: entry_node, - then shard all remaining weight tensors found in the subgraph defined between - entry_node and the next successor linear node. - First, validate if this is indeed a mamba module: within the subgraph, - there should be an torch_ssm node and conv1d node. 
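The fused-weight handling kept in shard_weight_tensor / split_fused_tensor boils down to sharding each fused block on its own and re-concatenating, so a downstream split() with per-rank sizes still lines up. A small standalone sketch (shard_fused is an illustrative name):

import torch

def shard_fused(weight, fused_dims, rank, world_size, dim=0):
    shards = []
    for block in torch.split(weight, fused_dims, dim=dim):
        # take this rank's slice of each fused block (e.g. z/x/B/C/dt of in_proj)
        shards.append(torch.tensor_split(block, world_size, dim=dim)[rank])
    return torch.cat(shards, dim=dim)

w = torch.arange(24.0).reshape(12, 2)       # fused [4, 8] along dim 0
local = shard_fused(w, [4, 8], rank=0, world_size=2)
assert local.shape == (6, 2)                # 4//2 + 8//2 rows on each rank
# rank 0 keeps the first half of each block, not the first half of the tensor
assert torch.equal(local, torch.cat([w[0:2], w[4:8]], dim=0))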
- - Args: - gm: GraphModule - entry_node: The first linear node of the Mamba layer - dim: Default shard dimension - rank: Current rank - world_size: Total number of ranks - add_dist: Whether to add distribution op after entry_node - min_local_shape: Minimum local shape constraint - weights_to_shard: Optional list of regex patterns to match weight names - weight_shard_dims: Optional dict mapping weight keys to their shard dimensions - fused_weight_dims: Optional dict mapping weight keys to their fused dimension lists - quantization_cb: Optional quantization callback - """ - # Find next linear node to define subgraph boundary - try: - next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False) - except RuntimeError: - ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") - return False - - # Get subgraph between entry_node and next linear node - subgraph_nodes = subgraph([entry_node], [next_lin_node]) - - ############################################################## - ########## validate if this is a valid Mamba module ########## - ############################################################## - # has_ssm = any(is_op(n, torch.ops.auto_deploy.mamba.torch_ssm_transform) for n in subgraph_nodes) - has_ssm = True - conv1d_nodes = [ - n - for n in subgraph_nodes - if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d]) - ] - if len(conv1d_nodes) != 1 or not has_ssm: - ad_logger.warning( - f"Subgraph does not contain exactly one conv1d node and torch_ssm_transform. " - f"Skipping Mamba sharding. conv1d_nodes={conv1d_nodes}, has_ssm={has_ssm}" - ) - return False - - ############################################################## - ########## infer split sizes for in_proj and conv1d ########## - ############################################################## - # in_proj and conv1d are most likely fused, followed up by split nodes. Infer split sizes: - if fused_weight_dims is None: - split_nodes = [ - n - for n in subgraph_nodes - if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) - ] - if len(split_nodes) != 2: - ad_logger.warning( - f"Subgraph does not contain exactly two split nodes. " - f"Skipping Mamba sharding. split_nodes={split_nodes}" - ) - return False - split_sizes_1 = split_nodes[0].args[1] - split_sizes_2 = split_nodes[1].args[1] - if split_sizes_1[1] != sum(split_sizes_2): - ad_logger.warning( - f"Split nodes have different sizes. " - f"Skipping Mamba sharding. split_sizes_1={split_sizes_1}, split_sizes_2={split_sizes_2}" - ) - return False - fused_weight_dims = { - "in_proj": split_sizes_1[0:1] + split_sizes_2 + split_sizes_1[2:], - "conv1d": split_sizes_2, - } - - ############################################################## - ############# update conv1d num output channels ############## - ############################################################## - conv1d_node = conv1d_nodes[0] - # conv1d_node last argument is the number of output channels. 
- # This one is also sharded, so we need to update this parameter - conv_args = list(conv1d_node.args) - conv_args[-1] = conv1d_node.args[-1] // world_size - conv1d_node.args = tuple(conv_args) - - # First, shard the entry_node (the first linear layer) - # Extract entry node's fused_weight_dims by matching weight name against patterns - entry_fused_dims = None - if fused_weight_dims: - entry_weight_key, _ = extract_param_names_from_node(entry_node) - for pattern, dims in fused_weight_dims.items(): - if re.search(pattern, entry_weight_key): - entry_fused_dims = dims - break - - ############################################################## - ####### shard the entry_node (the first linear layer) ######## - ############################################################## - _shard_parameter_node( - gm=gm, - node=entry_node, - dim=dim, - rank=rank, - world_size=world_size, - add_dist=add_dist, - min_local_shape=min_local_shape, - fused_weight_dims=entry_fused_dims, - quantization_cb=quantization_cb, - ) - - ############################################################## - ############## shard the remaining weights ################### - ############################################################## - # Get all weight nodes in the subgraph except for out_proj - weight_nodes = [ - n - for n in get_all_weights_in_subgraph([entry_node], [next_lin_node]) - if "out_proj" not in str(n) - ] - - # Shard remaining weights, such as conv1d or RMSNorm - for weight_node in weight_nodes: - weight_key = weight_node.target - - # Filter by regex patterns if provided - if weights_to_shard is not None: - if not any(pattern in weight_key for pattern in weights_to_shard): - continue - - # Determine shard dimension for this weight - shard_dim = weight_shard_dims.get(weight_key, dim) if weight_shard_dims else dim - - # Get the weight parameter - try: - weight_param = gm.get_parameter(weight_key) - except AttributeError: - ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") - continue - - # Get fused dims for this weight if specified - fused_dims = None - for k, v in fused_weight_dims.items(): - if k in weight_key: - fused_dims = v - break - - # Shard the weight tensor (also updates the parameter in the module) - _, sharded_shape = shard_weight_tensor( - gm=gm, - weight_tensor=weight_param, - param_key=weight_key, - dim=shard_dim, - rank=rank, - world_size=world_size, - min_local_shape=min_local_shape, - fused_weight_dims=fused_dims, - ) - - ad_logger.debug( - f"Sharded weight {weight_key} on dim {shard_dim}: " - f"{weight_param.shape} -> {sharded_shape}" - ) - - def _shard_parameter_node( gm: GraphModule, node: Node, @@ -490,7 +308,8 @@ def _shard_parameter_node( # # # column shard with no gather: the output is sharded if not add_dist: - _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) + # if is_linear_op(node): + # _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) return # figure out the right dist op @@ -509,7 +328,10 @@ def _shard_parameter_node( def _update_node_args(node: Node, args: tuple) -> None: """Update the node's arguments with the new sharded arguments.""" + if "sharded" in node.meta and node.meta["sharded"]: + return node.args = args + node.meta["sharded"] = True ad_logger.debug( f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." 
) @@ -605,21 +427,6 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply TP sharding transformation to the graph module.""" - - # if self.layer_type == LayerType.MAMBA: - # _insert_sharded_mamba( - # gm=gm, - # entry_node=node, - # dim=self.split_dim.value, - # rank=self.rank, - # world_size=self.world_size, - # add_dist=self.dist_op is not None, - # min_local_shape=self.min_local_shape, - # fused_weight_dims=self.fused_weight_dims - # if isinstance(self.fused_weight_dims, dict) - # else None, - # ) - # else: _shard_parameter_node( gm=gm, node=node, @@ -1189,6 +996,22 @@ def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]: return EPShardingInfo +class ShardingSource(Enum): + """Enum for sharding source.""" + + HEURISTIC = "heuristic" + FACTORY = "factory" + + +class ShardingDim(Enum): + """Enum for sharding dimension.""" + + SSM = "ssm" + TP = "tp" + EP = "ep" + BMM = "bmm" + + class ShardingConfig(BaseModel): """Configuration for sharding the model.""" @@ -1197,8 +1020,10 @@ class ShardingConfig(BaseModel): world_size: int = Field(default=1) predefined_config: Optional[Dict[str, Any]] = None simple_shard_only: bool = Field(default=False) - use_sharding_from_factory: bool = False support_partial_config: bool = False + sharding_source: List[ShardingSource] = Field( + default_factory=lambda: [ShardingSource.HEURISTIC] + ) sharding_dims: List[str] = Field(default_factory=list) weight_sharding_transforms: List[WeightShardingInfo] = Field(default_factory=list) parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list) From 31a1de810c0d6723b7f7d341e640a5bdcc30018f Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Mon, 27 Oct 2025 03:53:37 -0700 Subject: [PATCH 12/12] Working nemotron with new setup Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 6 +- .../auto_deploy/models/patches/nemotron_h.py | 6 +- .../auto_deploy/transform/library/sharding.py | 43 ++-- .../_torch/auto_deploy/utils/node_utils.py | 4 + .../auto_deploy/utils/sharding_utils.py | 214 ++++++++++++++++-- 5 files changed, 236 insertions(+), 37 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 7a177b0751c..2c2e6dc761d 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -75,9 +75,9 @@ transforms: detect_sharding: stage: sharding simple_shard_only: false - use_sharding_from_factory: false - support_partial_config: false - sharding_dims: ['ssm'] #, 'tp'] #, 'ep', 'bmm'] + sharding_source: ['heuristic'] # ,'heuristic'] + support_partial_config: true + sharding_dims: ['ssm', 'tp', 'ep', 'bmm'] requires_shape_prop: true # TODO: (hg) need to ensure run_shape_prop after sharding. 
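Assuming pydantic's usual coercion of strings into value-based Enums, the YAML lists above ('heuristic', 'ssm', 'tp', ...) map directly onto the new ShardingSource/ShardingDim members; a minimal sketch of that assumption:

from enum import Enum
from typing import List
from pydantic import BaseModel, Field

class ShardingDim(Enum):
    SSM = "ssm"
    TP = "tp"
    EP = "ep"
    BMM = "bmm"

class DetectShardingCfg(BaseModel):
    sharding_dims: List[ShardingDim] = Field(default_factory=lambda: [ShardingDim.TP])

# YAML strings are coerced into enum members on model construction
cfg = DetectShardingCfg(sharding_dims=["ssm", "tp", "ep", "bmm"])
assert cfg.sharding_dims[0] is ShardingDim.SSM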
sharding_transform_executor: diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index 72ecd0945e2..4ed4d2b036e 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -118,9 +118,9 @@ def _set_sharding_config_patched(self, *args, **kwargs): self._sharding_config["tp_plan"] = { "in_proj": "mamba", "out_proj": "rowwise", - "up_proj": "colwise", - "down_proj": "rowwise", - "*": "gather", + # "up_proj": "colwise", + # "down_proj": "rowwise", + # "*": "gather", } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 16aa2c3c612..25b2330c38a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -48,6 +48,7 @@ ShardingTransformInfo, SplitDimension, WeightShardingInfo, + get_all_weights_in_subgraph, ) from ..interface import ( BaseTransform, @@ -143,6 +144,7 @@ class ShardingTransformConfig(TransformConfig): sharding_source: List[ShardingSource] = Field( default_factory=lambda: [ShardingSource.HEURISTIC] ) + support_partial_config: bool = Field(default=False) # Which sharding dimensions to run: any subset of {"tp", "ep", "bmm"} sharding_dims: List[ShardingDim] = Field( default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] @@ -201,15 +203,16 @@ def _apply( else ShardingConfigSource.UNKNOWN ) sharding_config.simple_shard_only = self.config.simple_shard_only + sharding_config.support_partial_config = self.config.support_partial_config sharding_config.sharding_dims = self.config.sharding_dims sharding_config.sharding_source = self.config.sharding_source sharding_config.validate_config() info = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True) - for source in shared_config.sharding_config.sharding_source: + for source in sharding_config.sharding_source: if source == ShardingSource.FACTORY: - if len(shared_config.sharding_config.get_predefined_config()) == 0: + if len(sharding_config.get_predefined_config()) == 0: ad_logger.warning( "No factory config found. 
Skipping sharding from factory config" ) @@ -252,7 +255,7 @@ def _process_ssm_sharding( """ # Find next linear node to define subgraph boundary try: - out_proj_node, depth = bfs(entry_node, is_linear_op, include_root=False) + out_proj_node, _ = bfs(entry_node, is_linear_op, include_root=False) except RuntimeError: ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding") return 0 @@ -288,9 +291,9 @@ def _process_ssm_sharding( "conv1d": split_sizes_1, } - # ############################################################## - # ############## update split nodes ############################ - # ############################################################## + # # ############################################################## + # # ############## update split nodes ############################ + # # ############################################################## split_args_0 = list(split_nodes[0].args) split_args_0[1] = [s // world_size for s in split_args_0[1]] split_args_1 = list(split_nodes[1].args) @@ -332,9 +335,9 @@ def _process_ssm_sharding( ) ) - ############################################################## - ####### shard the entry_node (the first linear layer) ######## - ############################################################## + # ############################################################## + # ####### shard the entry_node (the first linear layer) ######## + # ############################################################## sharding_config.add( WeightShardingInfo.from_node( entry_node, @@ -347,11 +350,15 @@ def _process_ssm_sharding( ) ) - ############################################################## - ############## shard the remaining weights ################### - ############################################################## - # Get all weight nodes in the subgraph except for out_proj (it has to be row-sharded) - weight_nodes = [n for n in subgraph_nodes if n.op == "get_attr" and "out_proj" not in n.target] + # ############################################################## + # ############## shard the remaining weights ################### + # ############################################################## + # # Get all weight nodes in the subgraph except for out_proj (it has to be row-sharded) + weight_nodes = [ + n + for n in get_all_weights_in_subgraph([entry_node], [out_proj_node]) + if "out_proj" not in str(n) + ] for weight_node in weight_nodes: weight_key = weight_node.target # Get the weight parameter @@ -381,9 +388,9 @@ def _process_ssm_sharding( ) ) - ############################################################## - ############## update the view and reshape nodes ############# - ############################################################## + # ############################################################## + # ############## update the view and reshape nodes ############# + # ############################################################## nodes_to_validate = [ n for n in subgraph_nodes if is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]) ] @@ -724,7 +731,7 @@ def detect_ssm_shard( # find all ssm nodes in the graph ssm_nodes = filtered_nodes(gm.graph.nodes, ops=torch.ops.auto_deploy.torch_ssm) - ssm_nodes = list(ssm_nodes)[1:2] + # ssm_nodes = list(ssm_nodes)[1:2] num_ssm_shards = 0 for ssm_node in ssm_nodes: # We assume that one ssm node defines a subgraph corresponding diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 877256bc608..c7e2b637f8f 100644 --- 
a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -379,6 +379,10 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]:
     return boundary_nodes
 
 
+def identify_layer_subgraphs(gm: GraphModule) -> None:
+    pass
+
+
 def bfs(
     node: Node,
     target: Callable,
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
index 11a315ed2a8..1ee817b0d30 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
@@ -2,6 +2,7 @@
 
 import math
 import operator
+import re
 from abc import ABC, abstractmethod
 from enum import Enum, IntEnum
 from functools import partial
@@ -224,6 +225,176 @@ def get_all_weights_in_subgraph(
     return weight_nodes
 
 
+def _insert_sharded_mamba(
+    gm: GraphModule,
+    entry_node: Node,
+    dim: int,
+    rank: int,
+    world_size: int,
+    add_dist: bool = False,
+    min_local_shape: int = 1,
+    weights_to_shard: Optional[list[str]] = None,
+    weight_shard_dims: Optional[Dict[str, int]] = None,
+    fused_weight_dims: Optional[Dict[str, list]] = None,
+    quantization_cb: Optional[
+        Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
+    ] = None,
+) -> bool:
+    """
+    To shard a Mamba layer, first column-shard the first linear layer (entry_node),
+    then shard all remaining weight tensors found in the subgraph defined between
+    entry_node and the next successor linear node.
+    Before sharding, validate that this is indeed a Mamba module: within the
+    subgraph, there should be a torch_ssm node and a conv1d node.
+
+    Args:
+        gm: GraphModule
+        entry_node: The first linear node of the Mamba layer
+        dim: Default shard dimension
+        rank: Current rank
+        world_size: Total number of ranks
+        add_dist: Whether to add distribution op after entry_node
+        min_local_shape: Minimum local shape constraint
+        weights_to_shard: Optional list of regex patterns to match weight names
+        weight_shard_dims: Optional dict mapping weight keys to their shard dimensions
+        fused_weight_dims: Optional dict mapping weight keys to their fused dimension lists
+        quantization_cb: Optional quantization callback
+    """
+    # Find next linear node to define subgraph boundary
+    try:
+        next_lin_node, depth = bfs(entry_node, is_linear_op, include_root=False)
+    except RuntimeError:
+        ad_logger.warning("Could not find next linear node after entry_node for Mamba sharding")
+        return False
+
+    # Get subgraph between entry_node and next linear node
+    subgraph_nodes = subgraph([entry_node], [next_lin_node])
+
+    ##############################################################
+    ########## validate if this is a valid Mamba module ##########
+    ##############################################################
+    # has_ssm = any(is_op(n, torch.ops.auto_deploy.mamba.torch_ssm_transform) for n in subgraph_nodes)
+    has_ssm = True
+    conv1d_nodes = [
+        n
+        for n in subgraph_nodes
+        if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d])
+    ]
+    if len(conv1d_nodes) != 1 or not has_ssm:
+        ad_logger.warning(
+            f"Subgraph does not contain exactly one conv1d node and torch_ssm_transform. "
+            f"Skipping Mamba sharding. 
conv1d_nodes={conv1d_nodes}, has_ssm={has_ssm}" + ) + return False + + ############################################################## + ########## infer split sizes for in_proj and conv1d ########## + ############################################################## + # in_proj and conv1d are most likely fused, followed up by split nodes. Infer split sizes: + if fused_weight_dims is None: + split_nodes = [ + n + for n in subgraph_nodes + if is_op(n, [torch.ops.aten.split, torch.ops.aten.split_with_sizes]) + ] + if len(split_nodes) != 2: + ad_logger.warning( + f"Subgraph does not contain exactly two split nodes. " + f"Skipping Mamba sharding. split_nodes={split_nodes}" + ) + return False + split_sizes_1 = split_nodes[0].args[1] + split_sizes_2 = split_nodes[1].args[1] + if split_sizes_1[1] != sum(split_sizes_2): + ad_logger.warning( + f"Split nodes have different sizes. " + f"Skipping Mamba sharding. split_sizes_1={split_sizes_1}, split_sizes_2={split_sizes_2}" + ) + return False + fused_weight_dims = { + "in_proj": split_sizes_1[0:1] + split_sizes_2 + split_sizes_1[2:], + "conv1d": split_sizes_2, + } + + conv1d_node = conv1d_nodes[0] + # conv1d_node last argument is the number of output channels. + # This one is also sharded, so we need to update this parameter + conv_args = list(conv1d_node.args) + conv_args[-1] = conv1d_node.args[-1] // world_size + conv1d_node.args = tuple(conv_args) + + # First, shard the entry_node (the first linear layer) + # Extract entry node's fused_weight_dims by matching weight name against patterns + entry_fused_dims = None + if fused_weight_dims: + entry_weight_key, _ = extract_param_names_from_node(entry_node) + for pattern, dims in fused_weight_dims.items(): + if re.search(pattern, entry_weight_key): + entry_fused_dims = dims + break + + _shard_parameter_node( + gm=gm, + node=entry_node, + dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + add_dist=False, + min_local_shape=min_local_shape, + fused_weight_dims=entry_fused_dims, + ) + + # Get all weight nodes in the subgraph except for out_proj + weight_nodes = [ + n + for n in get_all_weights_in_subgraph([entry_node], [next_lin_node]) + if "out_proj" not in str(n) + ] + + # Shard remaining weights, such as conv1d or RMSNorm + for weight_node in weight_nodes: + weight_key = weight_node.target + + # Filter by regex patterns if provided + if weights_to_shard is not None: + if not any(pattern in weight_key for pattern in weights_to_shard): + continue + + # Determine shard dimension for this weight + shard_dim = weight_shard_dims.get(weight_key, dim) if weight_shard_dims else dim + + # Get the weight parameter + try: + weight_param = gm.get_parameter(weight_key) + except AttributeError: + ad_logger.debug(f"Could not get parameter for {weight_key}, skipping") + continue + + # Get fused dims for this weight if specified + fused_dims = None + for k, v in fused_weight_dims.items(): + if k in weight_key: + fused_dims = v + break + + # Shard the weight tensor (also updates the parameter in the module) + _, sharded_shape = shard_weight_tensor( + gm=gm, + weight_tensor=weight_param, + param_key=weight_key, + dim=shard_dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_dims, + ) + + ad_logger.debug( + f"Sharded weight {weight_key} on dim {shard_dim}: " + f"{weight_param.shape} -> {sharded_shape}" + ) + + def _shard_parameter_node( gm: GraphModule, node: Node, @@ -280,7 +451,7 @@ def _shard_parameter_node( rank=rank, world_size=world_size, 
min_local_shape=min_local_shape, - fused_weight_dims=None, + fused_weight_dims=fused_weight_dims, ) elif bias_key is not None and rank != world_size - 1: # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid @@ -308,8 +479,10 @@ def _shard_parameter_node( # # # column shard with no gather: the output is sharded if not add_dist: - # if is_linear_op(node): - # _validate_sharded_shapes(node, fused_weight_dims=fused_weight_dims, world_size=world_size) + if is_linear_op(node): + _validate_sharded_shapes( + node, fused_weight_dims=fused_weight_dims, world_size=world_size + ) return # figure out the right dist op @@ -397,6 +570,7 @@ class WeightShardingInfo(ShardingTransformInfo): split_dim: SplitDimension dist_op: Optional[Literal["all_reduce", "all_gather"]] = None min_local_shape: int = 1 + layer_type: LayerType = LayerType.MLP # used for TP sharding of fused weights fused_weight_dims: Optional[list] = None @@ -427,16 +601,30 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply TP sharding transformation to the graph module.""" - _shard_parameter_node( - gm=gm, - node=node, - dim=self.split_dim.value, - rank=self.rank, - world_size=self.world_size, - add_dist=self.dist_op is not None, - min_local_shape=self.min_local_shape, - fused_weight_dims=self.fused_weight_dims, - ) + if self.layer_type == LayerType.MAMBA: + _insert_sharded_mamba( + gm=gm, + entry_node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + fused_weight_dims=self.fused_weight_dims + if isinstance(self.fused_weight_dims, dict) + else None, + ) + else: + _shard_parameter_node( + gm=gm, + node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + fused_weight_dims=self.fused_weight_dims, + ) class ParameterUpdateInfo(ShardingTransformInfo):
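A note on the fused-weight path that WeightShardingInfo.apply now routes through: when a weight is a fusion of several segments (e.g. Mamba's in_proj), column sharding has to take each rank's slice of every fused segment and re-concatenate them, rather than chunking the whole dim 0. The snippet below is a self-contained illustration of that arithmetic only — it mirrors the intent of shard_weight_tensor with fused_weight_dims, not its actual implementation, and the segment sizes are made up.

import torch


def shard_fused_dim0(weight: torch.Tensor, fused_dims, rank: int, world_size: int) -> torch.Tensor:
    """Take this rank's slice of every fused segment along dim 0, then re-concatenate."""
    segments = torch.split(weight, fused_dims, dim=0)
    local = [seg.tensor_split(world_size, dim=0)[rank] for seg in segments]
    return torch.cat(local, dim=0)


# Toy fused weight with three segments (sizes are illustrative only).
w = torch.randn(8 + 4 + 4, 16)
local_w = shard_fused_dim0(w, [8, 4, 4], rank=0, world_size=2)
print(local_w.shape)  # torch.Size([8, 16]): every segment halved, fused layout preserved

Chunking the fused dim directly would hand rank 0 rows from a single segment only, so the downstream split of the fused output would no longer line up with the local weight — which is why the split sizes recovered from the graph are threaded through as fused_weight_dims.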