diff --git a/tests/models/jax/common/moe/test_deepseek_moe.py b/tests/models/jax/common/moe/test_deepseek_moe.py new file mode 100644 index 000000000..6983649b4 --- /dev/null +++ b/tests/models/jax/common/moe/test_deepseek_moe.py @@ -0,0 +1,221 @@ +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.sharding import Mesh, PartitionSpec + +from tpu_commons.models.jax.common.moe.deepseek_moe import (DeepSeekV3Router, + SparseMoE) + + +class TestDeepSeekV3Router(unittest.TestCase): + + def setUp(self): + self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', )) + + def test_get_topk_indices_single_group(self): + """Test get_topk_indices with single expert group.""" + with jax.set_mesh(self.cpu_mesh): + router = DeepSeekV3Router(random_init=True, + hidden_size=512, + num_experts=4, + num_experts_per_tok=2, + n_groups=1, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(42)) + router.bias_E = jnp.zeros((4, )) + + scores = jnp.array([[0.1, 0.3, 0.2, 0.4]]) # shape: (1, 4) + indices = router.get_topk_indices(scores) + + # Should return indices of top 2 experts + expected_indices = jnp.array([[3, + 1]]) # experts with scores 0.4, 0.3 + self.assertTrue(jnp.array_equal(indices, expected_indices)) + + def test_get_topk_indices_2_groups(self): + """Test get_topk_indices with 2 expert groups.""" + with jax.set_mesh(self.cpu_mesh): + router = DeepSeekV3Router(random_init=True, + hidden_size=512, + num_experts=4, + num_experts_per_tok=2, + n_groups=2, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(42)) + router.bias_E = jnp.zeros((4, )) + + # 4 experts, 2 groups, 2 experts per group + scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]]) # shape: (1, 1, 4) + indices = router.get_topk_indices(scores) + + # Should return indices of top 2 experts + expected_indices = jnp.array([[[3, 2]]]) + self.assertTrue(jnp.array_equal(indices, expected_indices)) + + def test_router_e2e(self): + with jax.set_mesh(self.cpu_mesh): + router = DeepSeekV3Router(random_init=True, + hidden_size=512, + num_experts=8, + num_experts_per_tok=2, + n_groups=2, + topk_groups=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + dtype=jnp.bfloat16, + rngs=nnx.Rngs(42)) + x = jnp.ones((2, 512)) + weights, indices = router(x) + self.assertEqual(weights.shape, (2, 2)) + self.assertEqual(indices.shape, (2, 2)) + + +class TestSparseMoE(unittest.TestCase): + + def setUp(self): + """Set up a multi-device mesh and a sample MoE layer for testing.""" + devices = jax.devices() + self.device_count = len(devices) + if self.device_count < 8: + self.skipTest("This test requires at least 8 simulated devices.") + + # This mesh will have a 'model' axis for expert parallelism + mesh_shape = (self.device_count, 1) + device_mesh_array = np.array(devices).reshape(mesh_shape) + + # Define the axis names + axis_names = ('model', 'data') + + # Create the 2D mesh + self.mesh = Mesh(device_mesh_array, axis_names=axis_names) + + # --- Model Configuration --- + self.B, self.S, self.D = 2, 4, 16 # Batch, Sequence, Hidden Dim + self.E, self.K = 16, 8 # Num Experts, Experts per Token + self.moe_intermediate_size = 32 # FFN Dim + self.num_expert_parallelism = 8 # Shard experts across 8 devices + + self.key = jax.random.PRNGKey(42) + self.x = jax.random.normal(self.key, (self.B * self.S, self.D), + dtype=jnp.bfloat16) + + # --- Instantiate MoE Layer --- + # We need to do this inside the mesh context + 
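+        # (creating the modules under the mesh lets the PartitionSpecs passed
+        #  below be resolved against the mesh axis names)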
+        with self.mesh:
+            router = DeepSeekV3Router(hidden_size=self.D,
+                                      num_experts=self.E,
+                                      num_experts_per_tok=self.K,
+                                      n_groups=1,
+                                      topk_groups=1,
+                                      norm_topk_prob=False,
+                                      routed_scaling_factor=1.0,
+                                      dtype=jnp.bfloat16,
+                                      rngs=nnx.Rngs(self.key),
+                                      ed_sharding=PartitionSpec(),
+                                      e_sharding=PartitionSpec(),
+                                      activation_ffw_td=PartitionSpec(
+                                          'data', None))
+            self.moe = SparseMoE(
+                hidden_size=self.D,
+                intermediate_size_moe=self.moe_intermediate_size,
+                num_local_experts=self.E,
+                hidden_act="silu",
+                num_experts_per_tok=self.K,
+                router=router,
+                dtype=jnp.bfloat16,
+                rngs=nnx.Rngs(self.key),
+                mesh=self.mesh,
+                apply_expert_weight_before_computation=False,
+
+                # Experts are sharded along the 'model' axis; activations along 'data'.
+                edf_sharding=PartitionSpec('model', None, None),
+                efd_sharding=PartitionSpec('model', None, None),
+                activation_ffw_ted=PartitionSpec('data', None),
+                activation_ffw_td=PartitionSpec(
+                    'data', None)  # 'data' axis has size 1, so tokens are effectively replicated.
+            )
+
+    def test_token_replicated_expert_parallel_fwd(self):
+        """
+        Validates the MoE forward pass against a simple, dense equivalent.
+        This specifically tests the is_batch_sharded_by_expert=False path.
+        """
+        # --- 1. Get the ACTUAL output from the distributed MoE layer ---
+        # The __call__ method will trigger the shard_map, which requires the mesh context.
+        with self.mesh:
+            actual_output = self.moe(self.x)
+
+        # --- 2. Calculate the EXPECTED output using a simple, sequential process ---
+        # This serves as the "ground truth".
+
+        # Get router decisions (router params are replicated, so this is fine)
+        router_weights, selected_experts = self.moe.router(self.x)
+
+        # --- Gather the full, unsharded weights from all devices ---
+        # .value on a sharded param gives the *local* shard.
+        # jax.device_get() retrieves the *full* GlobalDeviceArray to the host.
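+        # (this is a single-process test, so every shard is host-addressable and
+        #  device_get returns the complete (E, D, F) / (E, F, D) kernels)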
+ gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value) + up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value) + down_proj_kernel_full = jax.device_get( + self.moe.kernel_down_proj_EFD.value) + + # Check that we really got the full weights + self.assertEqual(gating_kernel_full.shape, + (self.E, self.D, self.moe_intermediate_size)) + + # Flatten inputs for easier iteration + flat_x = self.x.reshape(self.B * self.S, self.D) + flat_weights = router_weights.reshape(self.B * self.S, self.K) + flat_experts = selected_experts.reshape(self.B * self.S, self.K) + + expected_output = jnp.zeros_like(flat_x) + + # Manually apply each expert to each token sequentially + for i in range(self.B * self.S): # For each token + token_input = flat_x[i] + combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16) + + for k in range(self.K): # For each chosen expert for that token + expert_idx = flat_experts[i, k] + weight = flat_weights[i, k] + + # Get kernels from the *full* gathered arrays --- + gating_kernel = gating_kernel_full[expert_idx] + up_proj_kernel = up_proj_kernel_full[expert_idx] + down_proj_kernel = down_proj_kernel_full[expert_idx] + + # Perform the expert computation (dense matmuls) + gating_proj = jnp.dot(token_input, gating_kernel) + up_proj = jnp.dot(token_input, up_proj_kernel) + + # Note: Assuming 'silu' activation as specified in MoE init + fused = nnx.silu(gating_proj) * up_proj + + expert_output = jnp.dot(fused, down_proj_kernel) + + # Apply router weight after computation (matches implementation) + combined_expert_output += weight * expert_output + + expected_output = expected_output.at[i].set(combined_expert_output) + + expected_output = expected_output.reshape(self.B * self.S, self.D) + + # --- 3. Compare the results --- + self.assertTrue( + jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2), + f"The output of the distributed MoE does not match the dense equivalent.\n" + f"Actual:\n{actual_output}\n" + f"Expected:\n{expected_output}") + print( + "\nāœ… Test Passed: Distributed MoE output matches the dense ground truth." + ) diff --git a/tpu_commons/models/jax/common/moe/deepseek_moe.py b/tpu_commons/models/jax/common/moe/deepseek_moe.py new file mode 100644 index 000000000..5da2a2b61 --- /dev/null +++ b/tpu_commons/models/jax/common/moe/deepseek_moe.py @@ -0,0 +1,622 @@ +import enum +from dataclasses import InitVar, dataclass +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +from flax import nnx +from flax.typing import Sharding +from jax.sharding import PartitionSpec +from jaxtyping import Float + +from tpu_commons.models.jax.common.base import create_param +from tpu_commons.models.jax.common.layers import FlaxUtils +from tpu_commons.models.jax.common.moe.moe import MoE + +modeling_flax_utils = FlaxUtils() + + +@dataclass +class DeepSeekV3Router(nnx.Module): + """Router module for Mixture-of-Experts (MoE) layers. + + This module determines which experts each token should be routed to based on the input. 
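+
+    Scores are computed with a sigmoid over a learned (hidden_size, num_experts)
+    projection. A per-expert bias is added only when *selecting* the top-k
+    experts; the returned weights are taken from the unbiased scores. When
+    n_groups > 1, selection is first restricted to the topk_groups best expert
+    groups. The weights are optionally normalized and then scaled by
+    routed_scaling_factor.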
+ + """ + + hidden_size: int + num_experts: int + num_experts_per_tok: int + n_groups: int + topk_groups: int + norm_topk_prob: bool + routed_scaling_factor: float + dtype: jnp.dtype + rngs: InitVar[nnx.Rngs] + + # Sharding Attributes + activation_ffw_td: Sharding = () + ed_sharding: Sharding = () + e_sharding: Sharding = () + + random_init: bool = False + + router_bias_dtype: jnp.dtype = jnp.float32 + + def get_topk_indices(self, scores_TE: Float) -> Float: + """Get the topk indices of the scores. + + Args: + scores_TE: The scores to get the topk indices of. Shape (sequence, num_experts). + + Returns: + The topk indices of the scores. Shape (sequence, num_experts_per_tok). + """ + + scores_TE = scores_TE + self.bias_E + if self.n_groups > 1: + experts_per_group = self.num_experts // self.n_groups + group_scores_TGM = jnp.reshape( + scores_TE, (-1, self.n_groups, experts_per_group)) + group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0] + group_scores_TG = jnp.sum(group_scores_TG2, axis=-1) + indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1] + + mask_TG = jnp.any(jnp.arange( + self.n_groups)[:, None] == indices[..., None, :], + axis=-1) + mask_TE = jnp.repeat(mask_TG, + scores_TE.shape[-1] // mask_TG.shape[-1], -1) + scores_TE = jnp.where(mask_TE, scores_TE, 0.0) + + indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1] + + return indices_TX + + def __call__(self, x_TD: Float) -> Tuple[Float, Float]: + """Routes tokens to top k experts. + + Args: + x_TD: Input array of shape (sequence, d_model). + + Returns: + A tuple containing: + - weights: Normalized weights for selected experts, shape (sequence, num_experts_per_tok). + - indices: Indices of selected experts, shape (sequence, num_experts_per_tok). + """ + x_TD = jnp.asarray(x_TD, self.dtype) + x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td) + + scores_TE = jnp.einsum("TD,DE -> TE", x_TD, self.kernel_DE.value) + scores_TE = nnx.sigmoid(scores_TE) + + original_scores_TE = scores_TE + topk_indices_TX = self.get_topk_indices(scores_TE) + weights_TX = jnp.take_along_axis(original_scores_TE, + topk_indices_TX, + axis=-1) + + if self.norm_topk_prob: + weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20 + + weights_TX *= self.routed_scaling_factor + + return weights_TX, topk_indices_TX + + def __post_init__(self, rngs: nnx.Rngs): + """Generates the router kernel (weights and bias) for routing.""" + D = self.hidden_size + E = self.num_experts + self.kernel_DE = create_param(rngs, + shape=(D, E), + dtype=self.dtype, + sharding=self.ed_sharding, + random_init=self.random_init) + self.bias_E = create_param(rngs, + shape=(E, ), + dtype=self.router_bias_dtype, + sharding=self.e_sharding, + random_init=self.random_init) + + +@dataclass(kw_only=True) +class SparseMoE(MoE): + """Mixture-of-Experts (MoE) Routed MLP Layer. + + This module implements a Sparse MoE layer with a router and multiple expert MLPs. + + Attributes: + num_experts_per_tok: The number of experts each token is routed to. + tile_size: A tuple (batch, activation_dim, weight_dim) for GMM tiling. + use_megablox: If True, uses the MegaBlox GMM kernel. + mesh: The device mesh. + # TODO: need to redesign this I/O for parallelism + num_expert_parallelism: The size of the 'expert' mesh dimension. + # TODO: determine if we get it from external or extrat it in MoE class + is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim. 
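+        edf_sharding: Sharding for the (num_experts, hidden, ffw) gating and up-projection kernels.
+        efd_sharding: Sharding for the (num_experts, ffw, hidden) down-projection kernel.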
+ """ + edf_sharding: Sharding + efd_sharding: Sharding + num_experts_per_tok: int + #TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText + tile_size: tuple[int, int, int] = (128, 64, 128) + use_megablox: bool = False + mesh: jax.sharding.Mesh + + def __post_init__(self, rngs: nnx.Rngs): + + D = self.hidden_size + F = self.intermediate_size_moe + # shape_gating = (D, self.num_local_experts, F) + # shape_up = (D, self.num_local_experts, F) + # shape_down = (F, self.num_local_experts,D) + shape_gating = (self.num_local_experts, D, F) + shape_up = (self.num_local_experts, D, F) + shape_down = (self.num_local_experts, F, D) + + self.kernel_gating_EDF = create_param(rngs, + shape=shape_gating, + dtype=self.dtype, + sharding=self.edf_sharding, + random_init=self.random_init) + self.kernel_up_proj_EDF = create_param(rngs, + shape=shape_up, + dtype=self.dtype, + sharding=self.edf_sharding, + random_init=self.random_init) + self.kernel_down_proj_EFD = create_param(rngs, + shape=shape_down, + dtype=self.dtype, + sharding=self.efd_sharding, + random_init=self.random_init) + + # Derive the expert sharding + self.expert_axis_name = self.edf_sharding[0] + if self.expert_axis_name is None: + self.num_expert_parallelism = 1 + else: + self.num_expert_parallelism = self.mesh.shape[ + self.expert_axis_name] + + # Derive if data is sharded by expert + self.data_axis_name = self.activation_ffw_td[0] + self.is_batch_sharded_by_expert = ( + self.expert_axis_name is not None) and (self.expert_axis_name + == self.data_axis_name) + + def _sort_activations(self, inputs: jax.Array, + sort_indices: jax.Array) -> jax.Array: + """Sorts activations(inputs) by `sort_indices` for the forward pass.""" + return inputs[sort_indices, ...] + + @staticmethod + def get_all_to_all_params( + all_shards_group_sizes, + shard_id, + num_expert_parallelism, + is_batch_sharded=True, + ): + """Generates params for ragged_all_to_all communication.""" + + class TransformStrategy(enum.Enum): + INPUT_OFFSET = enum.auto() + SEND_SIZE = enum.auto() + OUTPUT_OFFSET = enum.auto() + RECV_SIZE = enum.auto() + + def transform_array(input_array, shard_id, strategy, is_batch_sharded): + if is_batch_sharded: + if strategy == TransformStrategy.INPUT_OFFSET: + local_array = input_array[shard_id] + return jnp.concatenate( + (jnp.array([0]), jnp.cumsum(local_array)[:-1])) + elif strategy == TransformStrategy.SEND_SIZE: + return input_array[shard_id] + elif strategy == TransformStrategy.OUTPUT_OFFSET: + zero_row = jnp.zeros((1, ) + input_array.shape[1:], + dtype=input_array.dtype) + array_with_zeros = jnp.concatenate((zero_row, input_array), + axis=0) + cumulated_array = jnp.cumsum(array_with_zeros, + axis=0, + dtype=input_array.dtype) + return cumulated_array[shard_id] + elif strategy == TransformStrategy.RECV_SIZE: + return input_array[:, shard_id] + else: + raise ValueError( + f"Unknown transform array strategy: {strategy}") + else: + if strategy == TransformStrategy.INPUT_OFFSET: + return jnp.zeros(num_expert_parallelism, + dtype=input_array.dtype) + elif strategy == TransformStrategy.SEND_SIZE: + return jnp.repeat(input_array[shard_id], + num_expert_parallelism) + elif strategy == TransformStrategy.OUTPUT_OFFSET: + output_offset = jnp.concatenate( + (jnp.array([0]), + jnp.cumsum(input_array[:-1])))[shard_id] + return jnp.repeat(output_offset, num_expert_parallelism) + elif strategy == TransformStrategy.RECV_SIZE: + return input_array + else: + raise ValueError( + f"Unknown transform array strategy: {strategy}") + + 
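+        # The four arrays below parameterize jax.lax.ragged_all_to_all: per peer
+        # shard, where this shard's outgoing segment starts (input_offsets), how
+        # many rows it sends (send_sizes), where those rows land in the peer's
+        # output buffer (output_offsets), and how many rows it receives
+        # (recv_sizes).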
input_offsets = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.INPUT_OFFSET, + is_batch_sharded) + send_sizes = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.SEND_SIZE, + is_batch_sharded) + output_offsets = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.OUTPUT_OFFSET, + is_batch_sharded) + recv_sizes = transform_array(all_shards_group_sizes, shard_id, + TransformStrategy.RECV_SIZE, + is_batch_sharded) + return input_offsets, send_sizes, output_offsets, recv_sizes + + def _local_permute( + self, + inputs, + global_group_sizes, + local_expert_size, + shard_index, + is_offset=False, + global_sorted_experts=None, + ): + """Permutes tokens locally within an expert shard.""" + # global_group_sizes: (tokens parallelism, num_total_experts) + # all_shard_local_sizes: (tokens parallelism, num local experts in the shard) + all_shard_local_sizes = jax.lax.dynamic_slice_in_dim( + global_group_sizes, + shard_index * local_expert_size, + local_expert_size, + axis=1, + ) + local_sizes = all_shard_local_sizes.reshape(-1) + + # local_group_size: (tokens parallelism, ) + local_group_size = jnp.sum(all_shard_local_sizes, axis=0) + + # When token replicated in devices + if is_offset: + global_sorted_shard_assignments = jnp.floor_divide( + global_sorted_experts, local_expert_size) + expert_indices = jnp.where( + global_sorted_shard_assignments == shard_index, + jnp.mod(global_sorted_experts, local_expert_size), + local_expert_size, + ) + + # When token sharded in devices + else: + base_indices = jnp.mod(jnp.arange(local_sizes.shape[0]), + local_expert_size) + expert_indices = jnp.repeat(base_indices, + local_sizes, + total_repeat_length=inputs.shape[0]) + + sorted_indices = jnp.argsort(expert_indices) + # sort the inputs based on the local expert_indices + sorted_inputs = self._sort_activations(inputs, sorted_indices) + # sortted local expert id from 0 to local expert size + sorted_experts_ids = expert_indices[sorted_indices] + return ( + sorted_inputs, + sorted_indices, + local_group_size, + sorted_experts_ids, + ) + + def _permute(self, inputs_TD: Float, selected_experts_TX: jax.Array): + """Global permute: Sorts tokens by assigned expert.""" + # suffix t = T * X = total_assignments for the local tokens(T) on this device. 
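+        # Example: with T=2 tokens, K=2 and selected_experts_TX=[[3, 1], [0, 3]],
+        # the flattened assignments [3, 1, 0, 3] give sort_indices_t=[2, 1, 0, 3]
+        # and group_sizes_E counts 1 token for experts 0 and 1, and 2 for expert 3.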
+ total_tokens = inputs_TD.shape[0] + flat_expert_indices = selected_experts_TX.flatten() + sort_indices_t = jnp.argsort(flat_expert_indices) + + replicated_inputs_tD = jnp.repeat(inputs_TD, + self.num_experts_per_tok, + axis=0) + sorted_inputs_tD = self._sort_activations(replicated_inputs_tD, + sort_indices_t) + + # number of tokens assigned to each expert + group_sizes_E = jnp.bincount(flat_expert_indices, + length=self.num_local_experts) + + expert_ids = jnp.arange(self.num_local_experts) + total_assignments = total_tokens * self.num_experts_per_tok + sorted_expert_assignments_t = jnp.repeat( + expert_ids, + repeats=group_sizes_E, + total_repeat_length=total_assignments) + + return ( + sorted_inputs_tD, + sort_indices_t, + group_sizes_E, + sorted_expert_assignments_t, + ) + + def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, + router_weights_TX: jax.Array): + """Unsorts tokens to their original order and combines expert outputs with router's weight.""" + with jax.named_scope("unpermute"): + unsorted_tokens_tD = self._sort_activations( + processed_tokens, jnp.argsort(sort_indices)) + D = unsorted_tokens_tD.shape[1] + reshaped_tokens_TXD = unsorted_tokens_tD.reshape( + -1, self.num_experts_per_tok, D) + # jax.debug.print( + # "āœ… reshaped_tokens_TXD on device: reshaped_tokens_TXD[5]={t}", + # t=reshaped_tokens_TXD[5, 0,:5] + # ) + # jax.debug.print( + # "āœ… router_weights_TX on device: router_weights_TX={t}", + # t=router_weights_TX[5, :] + # ) + with jax.named_scope("combine_weights"): + output_TD = jnp.einsum( + "TXD,TX -> TD", + reshaped_tokens_TXD.astype(self.dtype), + router_weights_TX.astype(self.dtype), + ) + + return output_TD.astype(self.dtype) + + def _gmm(self, inputs, kernel, group_sizes): + """Performs Grouped Matrix Multiply.""" + jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True) + num_rows = inputs.shape[0] + pad_amount = (self.tile_size[0] - + num_rows % self.tile_size[0]) % self.tile_size[0] + if pad_amount > 0: + inputs = jnp.pad(inputs, ((0, pad_amount), (0, 0))) + + if self.use_megablox: + #TODO: megablox is used in MaxText, keep a placeholder here for future implement + raise NotImplementedError( + "MegaBlox kernel call is not implemented.") + else: + + output = jax.lax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=group_sizes, + preferred_element_type=self.dtype, + ) + + if pad_amount > 0: + output = output[:num_rows, :] + return output + + @staticmethod + def _distributed_sparse_moe_fwd( + self, + x_TD: jax.Array, + router_weights_TX: jax.Array, + selected_experts_TX: jax.Array, + kernel_gating: jax.Array, + kernel_up_proj: jax.Array, + kernel_down_proj: jax.Array, + ): + """ + The sparse MoE forward pass with fully distributed logic. + This assumes it is running within a distributed TPU. + """ + + # 1. 
Global Permute, perpute all tokens across shards + ( + sorted_inputs, + global_sort_indices, + global_group_sizes, + global_sorted_experts, + ) = self._permute(x_TD, selected_experts_TX) + + # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis + # or we sould derive it from the model init + + local_expert_size = self.num_local_experts // self.num_expert_parallelism + + #if self.num_expert_parallelism > 1: + if self.expert_axis_name: + expert_shard_id = jax.lax.axis_index(self.expert_axis_name) + if self.is_batch_sharded_by_expert: + # When token sharded in devices + # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name + + # 2a. Send Tokens To Experts (All-to-All) + # Gather group sizes from all data shards + # all_shards_group_sizes: (data parallelism = expert parallelism, number of total experts ) + all_shards_group_sizes = jax.lax.all_gather( + global_group_sizes, axis_name=self.data_axis_name) + + # all_shards_group_sizes_per_expert_shard[i][j] = # tokens on shard[i] to be sent to expert shard[j] + all_shards_group_sizes_per_expert_shard = jnp.sum( + all_shards_group_sizes.reshape( + self.num_expert_parallelism, # data parallelism + self.num_expert_parallelism, # expert parallelism + local_expert_size # Experts per shard + ), + axis=2) + input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params( + all_shards_group_sizes_per_expert_shard, expert_shard_id, + self.num_expert_parallelism) + # Estimate buffer size + local_total_assignments = x_TD.shape[ + 0] * self.num_experts_per_tok + global_total_assignments = local_total_assignments * self.num_expert_parallelism + output_shape_est = jnp.zeros( + (global_total_assignments, self.hidden_size), + dtype=sorted_inputs.dtype) + + inputs_after_all2all = jax.lax.ragged_all_to_all( + sorted_inputs, + output_shape_est, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self.expert_axis_name) + + # 3a. Local Permute + # Get full group sizes from all shards + full_global_group_sizes = jax.lax.all_gather( + global_group_sizes, axis_name=self.expert_axis_name) + ( + compute_inputs, + local_sorted_indices, + compute_group_sizes, + compute_expert_ids, + ) = self._local_permute( + inputs_after_all2all, + full_global_group_sizes, + local_expert_size, + shard_index=expert_shard_id, + is_offset=False, + ) + + else: + # When token replicated in devices + + # 2. No send all-to-all needed, as the tokens are sorted and replicated on all devices + # 3b. Local "Permute" + ( + compute_inputs, + local_sorted_indices, + compute_group_sizes, + compute_expert_ids, + ) = self._local_permute( + sorted_inputs, + global_group_sizes[None, :], + local_expert_size, + shard_index=expert_shard_id, + is_offset=True, + global_sorted_experts=global_sorted_experts, + ) + + # Calculate group sizes for return all-to-all + reshaped_group_sizes = jnp.sum(global_group_sizes.reshape( + -1, local_expert_size), + axis=1) + mask = compute_expert_ids < local_expert_size + compute_inputs = compute_inputs * mask[..., None] + + else: + # --- NO EXPERT PARALLELISM --- + compute_inputs = sorted_inputs + compute_group_sizes = global_group_sizes + compute_expert_ids = global_sorted_experts + local_sorted_indices = jnp.arange(sorted_inputs.shape[0]) + + # 4. 
Compute: Apply experts using Grouped Matrix Multiply + with jax.named_scope("gating"): + # compute_inputs: (local total assignments, D) + gating_TEF = self._gmm(compute_inputs, kernel_gating, + compute_group_sizes) + activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act]( + gating_TEF) + + with jax.named_scope("up_projection"): + up_proj_TEF = self._gmm(compute_inputs, kernel_up_proj, + compute_group_sizes) + + fuse_TEF = activated_gating_TEF * up_proj_TEF + + with jax.named_scope("down_projection"): + # intermediate_output: (local total assignments, D) + intermediate_output = self._gmm(fuse_TEF, kernel_down_proj, + compute_group_sizes) + + # 5. Return Results (All-to-All) + if self.num_expert_parallelism > 1: + local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok + D = x_TD.shape[1] + output_shape = jnp.zeros( + (local_total_assignments, D), + dtype=intermediate_output.dtype) + + if self.is_batch_sharded_by_expert: + # When token sharded in devices + # Unsort locally before sending back + local_output = self._sort_activations( + intermediate_output, jnp.argsort(local_sorted_indices)) + + input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params( + jnp.transpose(all_shards_group_sizes), + expert_shard_id, + self.num_expert_parallelism, + ) + final_intermediate_output = jax.lax.ragged_all_to_all( + local_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self.expert_axis_name) + else: + # When token replicated in devices + input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params( + reshaped_group_sizes, + expert_shard_id, + self.num_expert_parallelism, + is_batch_sharded=False, + ) + final_intermediate_output = jax.lax.ragged_all_to_all( + intermediate_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self.expert_axis_name) + else: + final_intermediate_output = intermediate_output + + # 6. 
Global Unpermute (on the data shard) + with jax.named_scope("unpermute"): + output_TD = self._unpermute(final_intermediate_output, + global_sort_indices, router_weights_TX) + + return output_TD + + def __call__(self, x_TD: Float): + """Performs the forward pass of the Sparse MoE layer.""" + x_TD = jnp.asarray(x_TD, self.dtype) + x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td) + router_weights_TX, selected_experts_TX = self.router(x_TD) + + in_specs = ( + PartitionSpec(), # Replicated `self` + PartitionSpec(*self.activation_ffw_td), # Sharded x_TD + PartitionSpec(), # Replicated router_weights_TX + PartitionSpec(), # Replicated selected_experts_TX + PartitionSpec(*self.edf_sharding), # Sharded gating kernel + PartitionSpec(*self.edf_sharding), # Sharded up-projection kernel + PartitionSpec( + *self.efd_sharding), # Sharded down-projection kernel + ) + out_specs = PartitionSpec(*self.activation_ffw_td) + + mapped_moe_fwd = partial(jax.experimental.shard_map.shard_map, + mesh=self.mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False)( + SparseMoE._distributed_sparse_moe_fwd) + + return mapped_moe_fwd( + self, + x_TD, + router_weights_TX, + selected_experts_TX, + self.kernel_gating_EDF.value, + self.kernel_up_proj_EDF.value, + self.kernel_down_proj_EFD.value, + ) diff --git a/tpu_inference/layers/jax/attention/deepseek_v3_attention.py b/tpu_inference/layers/jax/attention/deepseek_v3_attention.py index a1634b923..3f086ee1a 100644 --- a/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +++ b/tpu_inference/layers/jax/attention/deepseek_v3_attention.py @@ -317,13 +317,13 @@ def attention( self.query_tnh, # q self.keyvalue_skh, # k self.keyvalue_skh, # v - P(None, None, "model"), # kv_cache + P(None, None, ('model', 'expert')), # kv_cache P(), # md.seq_lens: Replicated P(), # page_indices_flat: Replicated P(), # query_start_loc: Replicated P(), # distribution: Replicated ) - out_specs = (self.attn_o_tnh, P(None, None, "model")) + out_specs = (self.attn_o_tnh, P(None, None, ('model', 'expert'))) def _ragged_paged_attention(*args): return ragged_paged_attention( diff --git a/tpu_inference/layers/jax/moe/deepseek_v3_moe.py b/tpu_inference/layers/jax/moe/deepseek_v3_moe.py index 4aff8b9e8..8471f0a9a 100644 --- a/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +++ b/tpu_inference/layers/jax/moe/deepseek_v3_moe.py @@ -19,7 +19,7 @@ manually_quantize_qwix_activation, manually_quantize_qwix_weight) modeling_flax_utils = FlaxUtils() - +jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True), @dataclass class DeepSeekV3Router(nnx.Module): @@ -329,8 +329,9 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, with jax.named_scope("unpermute"): unsorted_tokens_tD = self._sort_activations( processed_tokens, jnp.argsort(sort_indices)) + D = unsorted_tokens_tD.shape[-1] reshaped_tokens_TXD = unsorted_tokens_tD.reshape( - -1, self.num_experts_per_tok, self.hidden_size) + -1, self.num_experts_per_tok, D) with jax.named_scope("combine_weights"): output_TD = jnp.einsum( "TXD,TX -> TD", @@ -394,10 +395,10 @@ def _distributed_sparse_moe_fwd( # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis # or we sould derive it from the model init - expert_shard_id = jax.lax.axis_index(self.expert_axis_name) - local_expert_size = self.num_local_experts // self.num_expert_parallelism if self.num_expert_parallelism > 1: + expert_shard_id = jax.lax.axis_index(self.expert_axis_name) + 
local_expert_size = self.num_local_experts // self.num_expert_parallelism if self.is_batch_sharded_by_expert: # When token sharded in devices # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name diff --git a/tpu_inference/layers/jax/moe/moe.py b/tpu_inference/layers/jax/moe/moe.py index c8e08ea40..a32ccb96f 100644 --- a/tpu_inference/layers/jax/moe/moe.py +++ b/tpu_inference/layers/jax/moe/moe.py @@ -84,8 +84,8 @@ class MoE(nnx.Module): router: nnx.Module activation_ffw_td: Sharding activation_ffw_ted: Sharding - edf_sharding: Sharding - efd_sharding: Sharding + edf_sharding: Sharding = () + efd_sharding: Sharding = () random_init: bool = False def __call__(self, x_TD: Float): diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py index 2eb9ac9f9..881ad196b 100644 --- a/tpu_inference/models/common/model_loader.py +++ b/tpu_inference/models/common/model_loader.py @@ -202,7 +202,7 @@ def get_flax_model( model_class = _get_model_architecture( vllm_config.model_config.hf_config) jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh) - kv_cache_sharding = NamedSharding(mesh, PartitionSpec(None, None, "model")) + kv_cache_sharding = NamedSharding(mesh, PartitionSpec(None, None, ('model', 'expert'))) hidden_states_sharding = NamedSharding(mesh, PartitionSpec(None, None)) # (T, D) @@ -224,7 +224,7 @@ def run_model(graphdef, state, *args): model = nnx.merge(graphdef, state) return model(*args) - logits_sharding = NamedSharding(mesh, PartitionSpec(None, "model")) + logits_sharding = NamedSharding(mesh, PartitionSpec(None, ('model', 'expert'))) @functools.partial( jax.jit, diff --git a/tpu_inference/models/jax/deepseek_v3.py b/tpu_inference/models/jax/deepseek_v3.py index 4a28f4c75..eecd3dfa3 100644 --- a/tpu_inference/models/jax/deepseek_v3.py +++ b/tpu_inference/models/jax/deepseek_v3.py @@ -120,7 +120,7 @@ def __init__(self, hidden_size=hidden_size, dtype=dtype, rngs=self.rng, - vd_sharding=(('data', 'expert', 'model'), + vd_sharding=(('data', 'model', 'expert'), None), random_init=self.random_init) @@ -148,14 +148,14 @@ def _create_mla() -> MLA: rngs=self.rng, activation_attention_td=(None, None), activation_q_td=(None, None), - query_tnh=P(None, 'model', None), - keyvalue_skh=P(None, 'model', None), + query_tnh=P(None, ('model', 'expert'), None), + keyvalue_skh=P(None, ('model', 'expert'), None), activation_attention_out_td=(None, None), - attn_o_tnh=P(None, 'model', None), - q_da_sharding=(None, 'model'), - anh_sharding=(None, 'model', None), - kv_da_sharding=(None, 'model'), - nhd_sharding=('model', None, None)) + attn_o_tnh=P(None, ('model', 'expert'), None), + q_da_sharding=(None, ('model', 'expert')), + anh_sharding=(None, ('model', 'expert'), None), + kv_da_sharding=(None, ('model', 'expert')), + nhd_sharding=(('model', 'expert'), None, None)) for i in range(first_k_dense_replace): block = TransformerBlock( @@ -201,8 +201,8 @@ def _create_mla() -> MLA: routed_scaling_factor=2.5, dtype=dtype, activation_ffw_td=('data', None), - ed_sharding=('model', None), - e_sharding=('model', )) + ed_sharding=(None, None), + e_sharding=(None, )) if self.sparse_matmul: # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces custom_module = SparseMoE( @@ -216,12 +216,10 @@ def _create_mla() -> MLA: hidden_act=hidden_act, rngs=self.rng, random_init=self.random_init, - activation_ffw_td=('data', None), - activation_ffw_ted=('data', None, None), - 
edf_sharding=('model', None, None), - efd_sharding=('model', None, None), - quantized_dtype=self.weight_loader.quant_dtype - if self.weight_loader.is_model_quantized else None, + activation_ffw_td=('data', 'model'), + activation_ffw_ted=('data', None, 'model'), + edf_sharding=(None , 'model', 'expert'), + efd_sharding=(None , 'expert', 'model'), router=router) if is_moe_layer else DenseFFW( dtype=dtype, hidden_act=hidden_act, @@ -241,10 +239,10 @@ def _create_mla() -> MLA: hidden_act=hidden_act, rngs=self.rng, random_init=self.random_init, - activation_ffw_td=('data', None), + activation_ffw_td=('data', 'model'), activation_ffw_ted=('data', None, None), - edf_sharding=('model', None, None), - efd_sharding=('model', None, None), + edf_sharding=('expert', 'model', None), + efd_sharding=('expert', None, 'model'), router=router) if is_moe_layer else DenseFFW( dtype=dtype, hidden_act=hidden_act, @@ -304,8 +302,8 @@ def _create_mla() -> MLA: hidden_size=hidden_size, dtype=dtype, rngs=self.rng, - vd_sharding=(('data', 'expert', 'model'), None), - dv_sharding=(None, ('data', 'expert', 'model')), + vd_sharding=(('data', 'model', 'expert'), None), + dv_sharding=(None, ('data', 'model', 'expert')), random_init=self.random_init) # For compatibility with flax. @@ -365,7 +363,6 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size, "is_verbose", None) is not None self.num_routed_experts = num_local_experts self.model_dtype = model_dtype - self._transpose_map = { # dense mlp r"mlp\.down_proj": (1, 0), @@ -829,9 +826,10 @@ def load_weights(self, model_for_loading: nnx.Module): def weights_dequant_cpu(x: torch.Tensor, s: torch.Tensor, - output_dtype: jnp.dtype, + output_dtype: torch.dtype, block_size: int = 128) -> torch.Tensor: assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors" + torch_output_type = DTYPE_VIEW_MAP.get(jnp.dtype(output_dtype)) M, N = x.shape x = x.to(torch.float32) @@ -865,4 +863,4 @@ def weights_dequant_cpu(x: torch.Tensor, scale = s[M // block_size, j // block_size] y[M_main:M, j:j + block_size] = block * scale - return y.to(j2t_dtype(jnp.dtype(output_dtype))) + return y.to(torch_output_type) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 597f3c51a..3d815b458 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -251,7 +251,7 @@ def _precompile_select_from_array(self) -> None: indices_paddings=self.runner.num_reqs_paddings, hidden_dim=vocab_size, sharding=NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")), + PartitionSpec(None, ('model', 'expert'))), ) self._precompile_select_from_array_helper( name="select target tokens for spec decoding", @@ -259,7 +259,7 @@ def _precompile_select_from_array(self) -> None: indices_paddings=self.runner.num_logits_paddings, hidden_dim=vocab_size, sharding=NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")), + PartitionSpec(None, ('model', 'expert'))), only_equal_paddings=True, ) @@ -288,7 +288,7 @@ def _precompile_sampling(self) -> None: hsize = self.runner.model_config.get_vocab_size() for num_reqs in self.runner.num_reqs_paddings: sharding = NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")) + PartitionSpec(None, ('model', 'expert'))) logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16, sharding) for do_sampling in (True, False): @@ -373,7 +373,7 @@ def _precompile_rejection_sampler(self) -> None: for num_logits in self.runner.num_logits_paddings: 
for num_reqs in self.runner.num_reqs_paddings: sharding = NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")) + PartitionSpec(None, ('model', 'expert'))) target_probs = self._create_dummy_tensor( (num_logits, vocab_size), jnp.bfloat16, sharding) draft_token_ids = self._create_dummy_tensor((num_logits, ), diff --git a/tpu_inference/runner/kv_cache.py b/tpu_inference/runner/kv_cache.py index 236e86c5d..ccf2b674a 100644 --- a/tpu_inference/runner/kv_cache.py +++ b/tpu_inference/runner/kv_cache.py @@ -20,7 +20,7 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int, actual_head_dim: int, kv_dtype: any): """Gets the KV cache shape based on the mesh configuration.""" - model_cnt = mesh.shape["model"] + model_cnt = mesh.shape["model"] * mesh.shape["expert"] assert actual_num_kv_heads % model_cnt == 0 shape = list( rpa.get_kv_cache_shape(total_num_pages, page_size, @@ -66,7 +66,7 @@ def create_kv_caches( num_kv_heads, head_size, cache_dtype) - sharding = NamedSharding(mesh, PartitionSpec(None, None, "model")) + sharding = NamedSharding(mesh, PartitionSpec(None, None, ('model', 'expert'))) def _allocate() -> jax.Array: return jnp.empty( diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py index 250decbdf..f7977228e 100644 --- a/tpu_inference/runner/kv_cache_manager.py +++ b/tpu_inference/runner/kv_cache_manager.py @@ -55,7 +55,7 @@ def get_kv_cache_spec(self): # Pad num_kv_heads to multiple of TP size. num_kv_heads = common_utils.get_padded_num_heads( model_config.get_total_num_kv_heads(), - self.runner.mesh.shape["model"]) + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]) head_size = common_utils.get_padded_head_dim( model_config.get_head_size()) for i in range(model_config.get_num_layers(parallel_config)): @@ -78,7 +78,7 @@ def get_kv_cache_spec(self): hf_config = draft_model_config.hf_config num_kv_heads = common_utils.get_padded_num_heads( hf_config.num_key_value_heads, - self.runner.mesh.shape["model"]) + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]) head_size = common_utils.get_padded_head_dim( hf_config.hidden_size // hf_config.num_attention_heads) @@ -120,7 +120,7 @@ def get_kv_cache_spec(self): block_size=block_size, num_kv_heads=common_utils.get_padded_num_heads( attn_module.num_kv_heads, - self.runner.mesh.shape["model"]), + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]), head_size=common_utils.get_padded_head_dim( attn_module.head_size), dtype=self.runner.kv_cache_dtype, @@ -138,7 +138,7 @@ def get_kv_cache_spec(self): block_size=block_size, num_kv_heads=common_utils.get_padded_num_heads( attn_module.num_kv_heads, - self.runner.mesh.shape["model"]), + self.runner.mesh.shape["model"] * self.runner.mesh.shape["expert"]), head_size=common_utils.get_padded_head_dim( attn_module.head_size), dtype=self.runner.kv_cache_dtype) @@ -375,7 +375,7 @@ def transfer_kv_cache(self, f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes" ) sharding = NamedSharding(self.runner.mesh, - PartitionSpec(None, "model")) + PartitionSpec(None, ("model", "expert"))) if envs.VLLM_TPU_USING_PATHWAYS: from pathwaysutils.experimental import \ reshard as experimental_reshard
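
# Illustrative sketch (not part of the patch): how a combined ('model', 'expert')
# entry in a PartitionSpec shards one axis over the product of both mesh axes,
# which is why the padding above uses mesh.shape["model"] * mesh.shape["expert"].
# Assumes 8 local devices arranged as a 4x2 mesh; the array shape is made up.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2),
            axis_names=("model", "expert"))
kv_sharding = NamedSharding(mesh, PartitionSpec(None, None, ("model", "expert")))
# 16 padded KV heads -> 16 / (4 * 2) = 2 heads per device along axis 2.
kv_cache = jax.device_put(jnp.zeros((128, 16, 16, 64), jnp.bfloat16), kv_sharding)
print(kv_cache.sharding)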