diff --git a/recipes/configs/llama4/scout_17B_16E_full.yaml b/recipes/configs/llama4/scout_17B_16E_full.yaml index 8497b81ca3..cda341c6ed 100644 --- a/recipes/configs/llama4/scout_17B_16E_full.yaml +++ b/recipes/configs/llama4/scout_17B_16E_full.yaml @@ -18,7 +18,7 @@ output_dir: /tmp/torchtune/llama4_17Bx16E/full model: _component_: torchtune.models.llama4.llama4_scout_17b_16e -tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8 +tensor_parallel_dim: 1 # For multi-node training we recommend tensor_parallel_dim: 8 tensor_parallel_plan: _component_: torchtune.models.llama4.decoder_only_tp_plan data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP @@ -74,10 +74,10 @@ fsdp_cpu_offload: True # compile Dictionary with keys: "model", "loss", "optimizer_step" # enables torch.compile only for specified components. compile: False -# model: True -# loss: True -# optimizer_step: False -# scale_grads: True + # model: True + # loss: True + # optimizer_step: True + # scale_grads: True # Reduced precision dtype: bf16 @@ -93,4 +93,17 @@ log_level: INFO # DEBUG, WARN, etc. # Useful for understanding how to optimize memory and performance profiler: _component_: torchtune.training.setup_torch_profiler - enabled: False + enabled: True + output_dir: ${output_dir}/profiling_outputs + cpu: True + cuda: True + profile_memory: True + with_stack: True + record_shapes: True + with_flops: False + wait_steps: 3 + warmup_steps: 3 + active_steps: 1 + num_cycles: 1 + +# enable_fp8_training: True diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index df89d96c21..a808276fbc 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -16,6 +16,7 @@ from omegaconf import DictConfig, ListConfig from torch import nn +import torch.distributed as dist from torch.distributed import destroy_process_group, init_process_group from torch.distributed.tensor import DTensor from torch.distributed.tensor.parallel import parallelize_module @@ -147,6 +148,10 @@ def __init__(self, cfg: DictConfig) -> None: offload_ops_to_cpu=self.fsdp_cpu_offload or self._enable_async_checkpointing, ) + # group_name = "torchtune-finetune" + # pg = dist.distributed_c10d._get_default_group() + # torch._C._distributed_c10d._register_process_group(group_name, pg) + # init_process_group(self.distributed_backend, group_name=group_name) init_process_group(self.distributed_backend) # Initialize distributed variables @@ -328,6 +333,9 @@ def setup(self, cfg: DictConfig) -> None: compile = cfg.get("compile") compile_bool = bool(compile) self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + self._compile_mode = None # "max-autotune-no-cudagraphs" + # torch._inductor.config.cpp_wrapper = True + # torch._dynamo.config.capture_scalar_outputs = True self._compile_model = compile_bool self._compile_loss = compile_bool @@ -343,7 +351,7 @@ def setup(self, cfg: DictConfig) -> None: self._grad_scaler = training.scale_grads_ if self._compile_scale_grads: self._grad_scaler = torch.compile( - self._grad_scaler, backend=self._compile_backend + self._grad_scaler, backend=self._compile_backend, mode=self._compile_mode ) self._model = self._setup_model( @@ -380,6 +388,7 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer.step = torch.compile( self._optimizer.step, backend=self._compile_backend, + mode=self._compile_mode ) if self._resume_from_checkpoint: @@ -413,7 +422,7 @@ def setup(self, cfg: DictConfig) -> None: 
self._loss_fn.set_model_output(self._model) if self._compile_loss: - training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + training.compile_loss(self._loss_fn, mode=self._compile_mode, verbose=self._is_rank_zero) utils.log_rank_zero(self._logger, "Loss is initialized.") @@ -586,7 +595,7 @@ def _setup_model( model = config.instantiate(cfg_model) if self._compile_model: - training.compile_model(model, verbose=self._is_rank_zero) + training.compile_model(model, mode=self._compile_mode, verbose=self._is_rank_zero) if self._enable_fp8_training: # Requires https://github.com/pytorch/pytorch/pull/148922 @@ -810,6 +819,7 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: with self.activations_handling_ctx: outputs = self._model(**batch) + # print(f"XXX {dist.get_rank()} OUTPUTS:{outputs.shape} {outputs.dtype}") # post process for third party loss functions if not isinstance(self._loss_fn, SFTLoss): @@ -820,6 +830,7 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: # Compute loss loss = self._loss_fn(outputs, labels) + # print(f"XXX {dist.get_rank()} LOSS:{loss}") # free logits otherwise it peaks backward memory del outputs @@ -895,6 +906,9 @@ def train(self) -> None: pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero) self._dataloader.sampler.set_epoch(curr_epoch) for idx, batch in enumerate(self._dataloader): + b_tokens = batch["tokens"] + b_labels = batch["labels"] + # print(f"XXX R:{dist.get_rank()} BATCH:{idx} b_labels:{b_labels.shape} b_tokens:{b_tokens.shape}") # Start tracking CUDA memory for active steps for just the first epoch if ( self._is_rank_zero @@ -916,7 +930,9 @@ def train(self) -> None: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients + # print(f"XXX R:{dist.get_rank()} BATCH:{idx} current_num_tokens:{current_num_tokens}") current_loss = self._loss_step(batch) * current_num_tokens + # print(f"XXX R:{dist.get_rank()} BATCH:{idx} current_loss:{current_loss}") running_loss += current_loss # For optimizer in backward, we need to normalize before calling backward @@ -1068,6 +1084,26 @@ def cleanup(self) -> None: self._metric_logger.close() destroy_process_group() +# from torch.utils._python_dispatch import TorchDispatchMode +# import torch.utils._pytree as pytree +# from torch._higher_order_ops.flex_attention import flex_attention +# +# +# +# class Mode(TorchDispatchMode): +# def __torch_dispatch__(self, func, types, args=(), kwargs=None): +# r = torch.distributed.get_rank() +# print(f"XXX RANK[{r}] MODE._torch_dispatch_ {func} {types}") +# for a in pytree.tree_leaves(args): +# if issubclass(type(a), torch.Tensor): +# print(f"XXX RANK[{r}] {a.dtype} {a.shape}") +# else: +# print(f"XXX RANK[{r}] {a}") +# return func(*args, **kwargs) +# +# def flex_attention_mode_call(mode, *args, **kwargs): +# return flex_attention(*args, **kwargs) +# flex_attention.py_impl(Mode)(flex_attention_mode_call) @config.parse def recipe_main(cfg: DictConfig) -> None: @@ -1081,6 +1117,7 @@ def recipe_main(cfg: DictConfig) -> None: config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg) recipe = FullFinetuneRecipeDistributed(cfg=cfg) recipe.setup(cfg=cfg) + # with Mode(): recipe.train() recipe.cleanup() diff --git a/torchtune/models/llama4/_component_builders.py b/torchtune/models/llama4/_component_builders.py index d08370b064..bf9ae49fd4 100644 --- a/torchtune/models/llama4/_component_builders.py +++ 
b/torchtune/models/llama4/_component_builders.py @@ -38,6 +38,7 @@ TokenChoiceTopKRouter, ) from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear +from torchtune.utils._device import has_cuda_capability """ Component builders for the Llama4 model. @@ -180,6 +181,7 @@ def llama4_decoder( num_experts: int = 16, experts_per_token: int = 1, use_shared_expert: bool = True, + use_grouped_mm: bool = True, use_qk_norm: bool = True, moe_every_n_layers: Optional[int] = None, mlp_hidden_dim: Optional[int] = None, @@ -244,6 +246,11 @@ def llama4_decoder( raise ValueError( "Must pass local_chunk_size when enabling local chunked attention" ) + if use_grouped_mm and not has_cuda_capability(9, 0): + torchtune.utils.get_logger("WARNING")( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + use_grouped_mm = False head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads @@ -263,7 +270,6 @@ def llama4_decoder( ) layers = [] for i in range(num_layers): - mask_mod = None if skip_rope_interval is not None and (i + 1) % skip_rope_interval != 0: mask_mod = partial( @@ -300,6 +306,7 @@ def llama4_decoder( num_experts=num_experts, experts_per_token=experts_per_token, use_shared_expert=use_shared_expert, + use_grouped_mm=use_grouped_mm, ) else: mlp_layer = llama4_mlp(dim=embed_dim, hidden_dim=mlp_hidden_dim) @@ -355,6 +362,7 @@ def llama4_moe( num_experts: int = 8, experts_per_token: int = 1, use_shared_expert: bool = True, + use_grouped_mm: bool = True, ) -> MoE: """ Build the MoE layer associated with the Llama model. @@ -631,6 +639,7 @@ def lora_llama4_decoder( raise ValueError( "Must pass local_chunk_size when enabling local chunked attention" ) + head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads if use_scaled_rope: @@ -649,7 +658,6 @@ def lora_llama4_decoder( ) layers = [] for i in range(num_layers): - mask_mod = None if skip_rope_interval is not None and (i + 1) % skip_rope_interval != 0: mask_mod = partial( diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index f2f4985029..e4eb7390db 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -47,7 +47,7 @@ def compile_flex_attention(): # when compiled. To insulate it from the compiler, we wrap it with # compiler.disable so that it can be used regardless of whether the model # is compiled or not, and flex attention always remains compiled. - @torch.compiler.disable(recursive=False) + # @torch.compiler.disable(recursive=False) def compile_friendly_flex_attention( q: torch.Tensor, k: torch.Tensor, diff --git a/torchtune/modules/moe/experts.py b/torchtune/modules/moe/experts.py index 8b7984c786..c7a00d4acb 100644 --- a/torchtune/modules/moe/experts.py +++ b/torchtune/modules/moe/experts.py @@ -13,6 +13,11 @@ from torchtune.modules.peft import AdapterModule +@torch._dynamo.allow_in_graph +def _grouped_mm(x, w, offs): + return torch._grouped_mm(x, w, offs=offs) + + class GroupedExperts(nn.Module): """This class implements the grouped experts layer used in Mixture of Experts. Each expert is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202. 
@@ -31,6 +36,7 @@ def __init__( hidden_dim: int, num_experts: int = 1, activation: Callable = F.silu, + use_grouped_mm: bool = False, ): super().__init__() self.dim = dim @@ -39,6 +45,8 @@ def __init__( self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) self.up_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) self.act_fn = activation + self.rank = torch.distributed.get_rank() + self.use_grouped_mm = use_grouped_mm def reset_parameters(self) -> None: # Default initialization used by torch.nn.Linear @@ -50,6 +58,7 @@ def reset_parameters(self) -> None: # TODO: force no inference mode as a hack to get around # "Cannot set version_counter for inference tensor" @torch.inference_mode(mode=False) + @torch._dynamo.disable(recursive=False) def forward( self, x: torch.Tensor, @@ -64,28 +73,59 @@ def forward( Returns: torch.Tensor: tensor with shape ``(bsz * seq_len * experts_per_token, dim)`` """ - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x, - split_size_or_sections=num_tokens_per_expert.tolist(), - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - w1, w2, w3 = ( - self.gate_proj[expert_idx], - self.down_proj[expert_idx], - self.up_proj[expert_idx], + if not self.use_grouped_mm: + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + num_tokens_per_expert_list = num_tokens_per_expert.tolist() + if torch.compiler.is_compiling(): + for n in num_tokens_per_expert_list: + torch._check_is_size(n) + x = torch.split( + x, + split_size_or_sections=num_tokens_per_expert_list, + dim=0, ) - h = self.act_fn(torch.matmul(x_expert, w1)) - h = h * torch.matmul(x_expert, w3) - h = torch.matmul(h, w2) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.gate_proj[expert_idx], + self.down_proj[expert_idx], + self.up_proj[expert_idx], + ) + h = self.act_fn(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + return out + # grouped mm implementation + if num_tokens_per_expert is not None: + # https://github.com/pytorch/pytorch/pull/150374 + # NOTE: torch._gouped_mm requires bf16 dtypes + # and shapes to be multiple of 8 + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + w1, w2, w3 = ( + self.gate_proj, + self.down_proj, + self.up_proj, + ) + assert ( + x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16 + ), "torch._grouped_mm only supports bf16 dtypes" + h = F.silu(_grouped_mm(x, w1, offs=offsets)) + h = h * _grouped_mm(x, w3, offs=offsets) + out = _grouped_mm(h, w2, offs=offsets) + out[offsets[-1] :].zero_() return out diff --git a/torchtune/modules/moe/indices.py b/torchtune/modules/moe/indices.py new file mode 100644 index 0000000000..d509272ade --- /dev/null +++ b/torchtune/modules/moe/indices.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import triton +import triton.language as tl + + +__all__ = ["generate_permute_indices"] + + +# parallelized kernel +@triton.jit +def _fill_indices_kernel( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + total_tokens_per_expert_ptr, # Added to check for zero tokens + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + # get total tokens for this expert across all ranks + total_expert_tokens = tl.load(total_tokens_per_expert_ptr + expert_id) + + # loop over all ranks, skip if no tokens + if total_expert_tokens > 0: + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # we can skip this rank-expert pair if there are no tokens + if length > 0: + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + + +# ============== +# wrapper +# ============== + + +def fill_indices_wrapper( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + total_tokens_per_expert: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + block_size: int = 128, + max_blocks: int = 1024, # cap on total number of blocks to launch +): + # preallocate output + permuted_indices = torch.full( + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device + ) + + # write offsets is per local expert... + num_blocks = min(experts_per_rank, max_blocks) + # grid = one block per expert unless capped and then we loop... 
+ grid = (num_blocks,) + + # launch kernel + _fill_indices_kernel[grid]( + tokens_per_expert_group, + start_index_values, + write_offsets, + permuted_indices, + total_tokens_per_expert, # 'skip logic' check for zero tokens + experts_per_rank, + num_ranks, + BLOCK_SIZE=block_size, + ) + return permuted_indices + + +# reference +def fill_indices_cpu( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + total_tokens_per_expert: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, +): + # We need to preallocate the output - we ignore device and force it on cpu + # device = tokens_per_expert_group.device + permuted_indices = torch.full( + (max_len,), + -1, + dtype=torch.int32, + ) # device=device) + # Fill the permuted indices + # For each local expert + for e in range(experts_per_rank): + write_start = write_offsets[e].item() + total_tokens = total_tokens_per_expert[e].item() + + # For each remote rank + # we can skip this expert if it has no tokens, already filled with -1 + if total_tokens > 0: + for r in range(num_ranks): + i = r * experts_per_rank + e + start_index = start_index_values[i].item() + length = tokens_per_expert_group[i].item() + # Fill in the indices + if length > 0: + end_idx = min(write_start + length, max_len) + permuted_indices[write_start:end_idx] = torch.arange( + start_index, + start_index + (end_idx - write_start), + dtype=torch.int32, + # device=device, + ) + write_start += length + return permuted_indices + + +def generate_permute_indices( + tokens_per_expert_group: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + alignment: int, + use_cpu: bool = False, +): + """ + Prepare permutation indices and the number of tokens for each expert. + + Args: + tokens_per_expert_group: number of tokens for each expert from all ranks. + experts_per_rank: number of experts per rank. + num_ranks: number of ranks. + max_len: maximum length of the output index vector. + alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts. + use_cpu: whether to use CPU implementation. + + + Returns: + permuted_indices: Tensor of indices that map original token order to the expert-grouped order. + m_sizes: aligned number of tokens for each expert (padded to alignment boundary). + m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens. + + Explanatory details: + `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example: + From: | rank 0 | rank 1 | + To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 | + | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 | + """ + + # prefix sum to get start index of each expert (parallel scan kernel in future?) 
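+    # e.g. for the docstring layout [4, 2, 1, 3, 1, 2, 3, 4]:
+    #   cumsum             = [4, 6, 7, 10, 11, 13, 16, 20]
+    #   start_index_values = [0, 4, 6, 7, 10, 11, 13, 16]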
+ start_index_values = ( + torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group + ) + + # total tokens for each expert (sum over ranks) + total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + + # pad out empty experts to alignment requirement + padded_total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) + + # align the chunk sizes (cdiv) + m_sizes = ( + (padded_total_tokens_per_expert + alignment - 1) // alignment * alignment + ).to(torch.int32) + + # additional prefix sum to get write offset of each expert in permuted_indices + # write offsets is per local expert, not global + m_offsets = torch.cumsum(m_sizes, 0) + write_offsets = m_offsets - m_sizes + + # Select the implementation to use + if use_cpu: + permuted_indices = fill_indices_cpu( + tokens_per_expert_group, + start_index_values, + write_offsets, + total_tokens_per_expert, # Pass to check for zero tokens + experts_per_rank, + num_ranks, + max_len, + ) + else: + permuted_indices = fill_indices_wrapper( + tokens_per_expert_group, + start_index_values, + write_offsets, + total_tokens_per_expert, # Pass to check for zero tokens + experts_per_rank, + num_ranks, + max_len, + ) + + return permuted_indices, m_sizes, m_offsets.to(torch.int32) + + +# Below is for testing only + + +def simple_test(): + device = torch.device("cuda", 0) + experts_per_rank = 4 + num_ranks = 4 + tokens_per_expert_group = torch.full( + (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device + ) + max_len = 128 + alignment = 32 + # Use the GPU kernel + permuted_indices_gpu, m_sizes, _ = generate_permute_indices( + tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment + ) + # Use the CPU method + permuted_indices_cpu, m_sizes, _ = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + use_cpu=True, + ) + # Check that the results are the same + + assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu) + assert torch.equal( + torch.remainder(m_sizes, alignment), + torch.zeros(experts_per_rank, device=device), + ) + # Print the results + print(f"{permuted_indices_gpu=}, \n{permuted_indices_cpu=}") + print(f"{m_sizes=}") + print("Success") + return True # assert would have failed meaning getting here is success. 
+ + +def test_with_zero_tokens(): + device = torch.device("cuda", 0) + experts_per_rank = 4 + num_ranks = 2 + + # Create a test case where some experts have zero tokens + tokens_per_expert_group = torch.tensor( + [4, 0, 2, 3, 1, 0, 0, 5], # Some experts have zero tokens + dtype=torch.int32, + device=device, + ) + + max_len = 128 + alignment = 8 + + # Use the GPU kernel + permuted_indices_gpu, m_sizes, m_offsets = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + ) + + # Use the CPU method + permuted_indices_cpu, m_sizes_cpu, m_offsets_cpu = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + use_cpu=True, + ) + + # Check that the results are the same + assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu) + assert torch.equal(m_sizes, m_sizes_cpu) + + # Verify that experts with zero tokens have at least min_slots_per_expert + total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + zero_token_experts = total_tokens_per_expert == 0 + if zero_token_experts.any(): + assert (m_sizes[zero_token_experts] >= alignment).all() + + # Check alignment + assert torch.equal( + torch.remainder(m_sizes, alignment), + torch.zeros(experts_per_rank, device=device), + ) + + # Print the results + print(f"tokens_per_expert_group = {tokens_per_expert_group}") + print(f"total_tokens_per_expert = {total_tokens_per_expert}") + print(f"m_sizes = {m_sizes}") + print(f"m_offsets = {m_offsets}") + print(f"permuted_indices = {permuted_indices_gpu[:sum(m_sizes).item()]}") + + # Check that experts with zero tokens have -1 in their slots + for e in range(experts_per_rank): + start = (m_offsets[e] - m_sizes[e]).item() + end = m_offsets[e].item() + expert_indices = permuted_indices_gpu[start:end] + if total_tokens_per_expert[e] == 0: + assert ( + expert_indices == -1 + ).all(), f"Expert {e} with zero tokens should have all -1 indices" + assert ( + expert_indices.size(0) >= alignment + ), f"Expert {e} with zero tokens should have at least {alignment} slots" + print( + f"Expert {e} has zero tokens and {expert_indices.size(0)} slots with all -1" + ) + + print("All tests passed successfully!") + return True + + +if __name__ == "__main__": + simple_test() + test_with_zero_tokens() diff --git a/torchtune/modules/moe/moe.py b/torchtune/modules/moe/moe.py index b6fd008356..87dadba6c9 100644 --- a/torchtune/modules/moe/moe.py +++ b/torchtune/modules/moe/moe.py @@ -99,11 +99,13 @@ def __init__( experts: nn.Module, router: nn.Module, shared_expert: Optional[nn.Module] = None, + use_grouped_mm: bool = False, ): super().__init__() self.experts = experts self.router = router self.shared_expert = shared_expert + self.use_grouped_mm = use_grouped_mm def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -133,6 +135,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) routed_input = routed_input * top_scores.reshape(-1, 1) + if self.use_grouped_mm: + # NOTE: In order to use torch._grouped_mm, we need to make sure + # the number of tokens each expert gets is a multiple of 16. + # The following kernel helps achieve this via padding, without + # incurring synchronization between device and host. 
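+            # Padding slots come back from that kernel as index -1; after the
+            # extra zero row is vstacked below, -1 selects that zero row, so
+            # padded positions contribute nothing to the expert outputs.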
+ from torchtune.modules.moe.indices import generate_permute_indices + + ALIGN_SIZE_M = 16 + + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, + ) = generate_permute_indices( + num_tokens_per_expert, + self.experts.num_experts, + 1, + token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + token_indices = torch.vstack( + (token_indices, token_indices.new_zeros((dim))) + ) + token_indices = token_indices[permuted_indices, :] + routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) + routed_input = routed_input[permuted_indices, :] + # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) @@ -141,6 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.shared_expert(x).reshape(bs * slen, dim) else: out = torch.zeros_like(x.reshape(bs * slen, dim)) + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) return out diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 724138b14e..bd42e7db70 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -46,6 +46,7 @@ def __init__( self.sa_norm = sa_norm or nn.Identity() self.mlp_norm = mlp_norm or nn.Identity() self.sa_scale = sa_scale or nn.Identity() + print(f"XXX transformer.py self.sa_scale:{type(self.sa_scale)}") self.mlp_scale = mlp_scale or nn.Identity() self.mask_mod = mask_mod or None diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py index 1324045253..97ceaf0ff0 100644 --- a/torchtune/training/_compile.py +++ b/torchtune/training/_compile.py @@ -23,6 +23,7 @@ def compile_model( model: Union[TransformerDecoder, DeepFusionModel], + mode = None, verbose: bool = True, ) -> None: """ @@ -48,10 +49,10 @@ def compile_model( if isinstance(m, TransformerSelfAttentionLayer) or isinstance( m, TransformerCrossAttentionLayer ): - m.compile(backend=backend) + m.compile(backend=backend, mode=mode) -def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module: +def compile_loss(loss: nn.Module, mode=None, verbose: bool = True) -> nn.Module: """ Utility to compile and return loss function @@ -68,6 +69,6 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module: if hasattr(loss, "apply_compile_strategy"): loss = loss.apply_compile_strategy(backend=backend) else: - loss = torch.compile(loss, backend=backend) + loss = torch.compile(loss, backend=backend, mode=mode) return loss diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index bc156599a5..80a18e0fc5 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -37,6 +37,7 @@ from torchtune.modules.peft import get_adapter_state_dict from torchtune.utils import get_device, get_logger from torchtune.utils._logging import deprecated +from torch.distributed.fsdp import MixedPrecisionPolicy _log: logging.Logger = get_logger() @@ -646,6 +647,7 @@ def shard_model( fsdp_kwargs = {"reshard_after_forward": reshard_after_forward, "mesh": dp_mesh} if cpu_offload: fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + # fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(reduce_dtype=torch.float8_e5m2) # Shard the model with FSDP, iterating in reverse to start with # lowest-level modules first diff --git a/torchtune/training/_grad_scaler.py b/torchtune/training/_grad_scaler.py index 4ae2be70da..67fa634ff9 100644 --- a/torchtune/training/_grad_scaler.py +++ 
b/torchtune/training/_grad_scaler.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from collections import defaultdict -from typing import Optional +from typing import Optional, Iterable import torch from torch import nn, Tensor @@ -64,12 +64,14 @@ def scale_grads_( if isinstance(parameters, torch.Tensor): parameters = [parameters] else: - parameters = list(parameters) + # Graph Breaks + # Hint: Avoid calling builtin `list` with argument types ['generator']. Consider using an equivalent alternative function/method to `list`. + parameters = parameters _scale_grad_(parameters, scaler, foreach) def _group_tensors_by_device( - tensors: list[torch.Tensor], + tensors: Iterable[torch.Tensor], ) -> dict[torch.device, list[Tensor]]: ret = defaultdict(list) for i, tensor in enumerate(tensors): diff --git a/torchtune/training/_profiler.py b/torchtune/training/_profiler.py index c4b3d70174..b6a840cb25 100644 --- a/torchtune/training/_profiler.py +++ b/torchtune/training/_profiler.py @@ -85,8 +85,9 @@ def trace_handler( row_limit (int): number of rows to display in trace event table """ + import time world_size, rank = get_world_size_and_rank() - curr_trace_dir_name = "iteration_" + str(prof.step_num) + curr_trace_dir_name = "iteration_" + str(prof.step_num) + time.strftime("T%H-%M") curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) if not os.path.exists(curr_trace_dir): os.makedirs(curr_trace_dir, exist_ok=True) @@ -112,11 +113,14 @@ def trace_handler( log.info(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds") # Memory timeline sometimes fails to export + print(f"XXX prof.profile_memory:{prof.profile_memory}") if prof.profile_memory and torch.cuda.is_available(): if rank == 0: + print(f"XXX DUMP_MEMORY_SNAPSHOT") torch.cuda.memory._dump_snapshot( f"{curr_trace_dir}/rank{rank}_memory_snapshot.pickle" ) + print(f"XXX PROFILE_DIR:{curr_trace_dir}") # Dump stack traces if prof.with_stack: @@ -130,6 +134,7 @@ def trace_handler( print(key_avgs, file=f) if rank == 0: log.info(f"Saving profiling results to {curr_trace_dir}") + print(f"XXX CURRENT_TRACE_DIR:{curr_trace_dir}") # TODO: Is this necessary? # see https://github.com/pytorch/torchtitan/blob/3050098dcee4901d88c712f9e8e9703d1735a29b/torchtitan/profiling.py#L48 diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index ec267334ab..190a4612ad 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -265,3 +265,10 @@ def get_torch_device_namespace() -> any: f"Device namespace '{device_type}' not found in torch, try to load torch.cuda." ) return torch.cuda + + +def has_cuda_capability(major: int, minor: int) -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + major, + minor, + )
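
For reference, a rough single-rank sketch of how the new generate_permute_indices helper is meant to be used, mirroring the call in MoE.forward above. The numbers are illustrative only; it assumes Triton is importable (torchtune/modules/moe/indices.py imports it at the top even for the CPU path) and uses the CPU reference implementation so it runs without a GPU.

import torch

from torchtune.modules.moe.indices import generate_permute_indices

ALIGN_SIZE_M = 16
num_experts = 4
dim = 8

# routed tokens per expert on this (single) rank -- illustrative values
num_tokens_per_expert = torch.tensor([5, 0, 3, 12], dtype=torch.int32)
num_tokens = int(num_tokens_per_expert.sum())  # 20

# rows of routed_input are already grouped by expert at this point
routed_input = torch.randn(num_tokens, dim, dtype=torch.bfloat16)

max_len = num_tokens + num_experts * ALIGN_SIZE_M
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
    num_tokens_per_expert,
    num_experts,
    1,            # num_ranks
    max_len,
    ALIGN_SIZE_M,
    use_cpu=True,  # Triton kernel path needs CUDA; the CPU reference is fine here
)
print(m_sizes)    # tensor([16, 16, 16, 16], dtype=torch.int32) -- every expert padded to a multiple of 16
print(m_offsets)  # tensor([16, 32, 48, 64], dtype=torch.int32) -- usable as `offs` for torch._grouped_mm

# append one zero row; padding slots in permuted_indices are -1 and select it
routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim,))))
padded_input = routed_input[permuted_indices, :]
print(padded_input.shape)  # torch.Size([84, 8]); rows past m_offsets[-1] are padding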