25 changes: 19 additions & 6 deletions recipes/configs/llama4/scout_17B_16E_full.yaml
@@ -18,7 +18,7 @@ output_dir: /tmp/torchtune/llama4_17Bx16E/full
model:
_component_: torchtune.models.llama4.llama4_scout_17b_16e

tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
tensor_parallel_dim: 1 # For multi-node training we recommend tensor_parallel_dim: 8
tensor_parallel_plan:
_component_: torchtune.models.llama4.decoder_only_tp_plan
data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
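Note: data_parallel_shard_dim: -1 lets the recipe infer the FSDP shard dimension from whatever ranks remain after tensor parallelism. A minimal sketch of that arithmetic, assuming tensor parallelism is the only other parallel dim (variable names here are illustrative, not the recipe's):

world_size = 8                       # total ranks in the job
tensor_parallel_dim = 1              # value set above
# -1 is a placeholder for "use the remaining ranks" when sharding with FSDP
data_parallel_shard_dim = world_size // tensor_parallel_dim   # -> 8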
@@ -74,10 +74,10 @@ fsdp_cpu_offload: True
# compile Dictionary with keys: "model", "loss", "optimizer_step"
# enables torch.compile only for specified components.
compile: False
# model: True
# loss: True
# optimizer_step: False
# scale_grads: True
# model: True
# loss: True
# optimizer_step: True
# scale_grads: True

# Reduced precision
dtype: bf16
@@ -93,4 +93,17 @@ log_level: INFO # DEBUG, WARN, etc.
# Useful for understanding how to optimize memory and performance
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False
enabled: True
output_dir: ${output_dir}/profiling_outputs
cpu: True
cuda: True
profile_memory: True
with_stack: True
record_shapes: True
with_flops: False
wait_steps: 3
warmup_steps: 3
active_steps: 1
num_cycles: 1

# enable_fp8_training: True
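For reference, the profiler block above roughly corresponds to the following torch.profiler setup. This is an illustrative sketch of the public PyTorch API (train_step and the hard-coded output path are stand-ins), not the code torchtune's setup_torch_profiler actually emits:

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

def train_step() -> None:
    # stand-in for one training step of the recipe
    a = torch.randn(1024, 1024)
    b = torch.randn(1024, 1024)
    (a @ b).sum().item()

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],   # cpu: True, cuda: True
    schedule=schedule(wait=3, warmup=3, active=1, repeat=1),     # wait/warmup/active_steps, num_cycles
    on_trace_ready=tensorboard_trace_handler("/tmp/torchtune/llama4_17Bx16E/full/profiling_outputs"),
    profile_memory=True,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
)
with prof:
    for _ in range(7):   # one cycle: 3 wait + 3 warmup + 1 active step
        train_step()
        prof.step()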
43 changes: 40 additions & 3 deletions recipes/full_finetune_distributed.py
@@ -16,6 +16,7 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
import torch.distributed as dist
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
@@ -147,6 +148,10 @@ def __init__(self, cfg: DictConfig) -> None:
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,
)
# group_name = "torchtune-finetune"
# pg = dist.distributed_c10d._get_default_group()
# torch._C._distributed_c10d._register_process_group(group_name, pg)
# init_process_group(self.distributed_backend, group_name=group_name)
init_process_group(self.distributed_backend)

# Initialize distributed variables
@@ -328,6 +333,9 @@ def setup(self, cfg: DictConfig) -> None:
compile = cfg.get("compile")
compile_bool = bool(compile)
self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
self._compile_mode = None # "max-autotune-no-cudagraphs"
# torch._inductor.config.cpp_wrapper = True
# torch._dynamo.config.capture_scalar_outputs = True

self._compile_model = compile_bool
self._compile_loss = compile_bool
@@ -343,7 +351,7 @@ def setup(self, cfg: DictConfig) -> None:
self._grad_scaler = training.scale_grads_
if self._compile_scale_grads:
self._grad_scaler = torch.compile(
self._grad_scaler, backend=self._compile_backend
self._grad_scaler, backend=self._compile_backend, mode=self._compile_mode
)

self._model = self._setup_model(
@@ -380,6 +388,7 @@ def setup(self, cfg: DictConfig) -> None:
self._optimizer.step = torch.compile(
self._optimizer.step,
backend=self._compile_backend,
mode=self._compile_mode
)

if self._resume_from_checkpoint:
@@ -413,7 +422,7 @@ def setup(self, cfg: DictConfig) -> None:
self._loss_fn.set_model_output(self._model)

if self._compile_loss:
training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
training.compile_loss(self._loss_fn, mode=self._compile_mode, verbose=self._is_rank_zero)

utils.log_rank_zero(self._logger, "Loss is initialized.")

@@ -586,7 +595,7 @@ def _setup_model(
model = config.instantiate(cfg_model)

if self._compile_model:
training.compile_model(model, verbose=self._is_rank_zero)
training.compile_model(model, mode=self._compile_mode, verbose=self._is_rank_zero)

if self._enable_fp8_training:
# Requires https://github.com/pytorch/pytorch/pull/148922
@@ -810,6 +819,7 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:

with self.activations_handling_ctx:
outputs = self._model(**batch)
# print(f"XXX {dist.get_rank()} OUTPUTS:{outputs.shape} {outputs.dtype}")

# post process for third party loss functions
if not isinstance(self._loss_fn, SFTLoss):
@@ -820,6 +830,7 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:

# Compute loss
loss = self._loss_fn(outputs, labels)
# print(f"XXX {dist.get_rank()} LOSS:{loss}")

# free logits otherwise it peaks backward memory
del outputs
@@ -895,6 +906,9 @@ def train(self) -> None:
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
self._dataloader.sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
b_tokens = batch["tokens"]
b_labels = batch["labels"]
# print(f"XXX R:{dist.get_rank()} BATCH:{idx} b_labels:{b_labels.shape} b_tokens:{b_tokens.shape}")
# Start tracking CUDA memory for active steps for just the first epoch
if (
self._is_rank_zero
@@ -916,7 +930,9 @@

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
# print(f"XXX R:{dist.get_rank()} BATCH:{idx} current_num_tokens:{current_num_tokens}")
current_loss = self._loss_step(batch) * current_num_tokens
# print(f"XXX R:{dist.get_rank()} BATCH:{idx} current_loss:{current_loss}")
running_loss += current_loss

# For optimizer in backward, we need to normalize before calling backward
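As the comment above notes, each micro-batch's mean loss is scaled back up by its token count, and the running sum is normalized by the total number of tokens across the gradient-accumulation window so that every token is weighted equally. A minimal sketch of that pattern (model, loss_fn, micro_batches, and ignore_idx are illustrative placeholders, not the recipe's members):

import torch

def accumulate_and_backward(model, loss_fn, micro_batches, ignore_idx: int = -100) -> torch.Tensor:
    running_loss = torch.tensor(0.0)
    total_tokens = torch.tensor(0)
    for batch in micro_batches:
        num_tokens = (batch["labels"] != ignore_idx).sum()
        # loss_fn returns a per-token mean, so scale by this batch's token count
        running_loss = running_loss + loss_fn(model(batch["tokens"]), batch["labels"]) * num_tokens
        total_tokens += num_tokens
    # normalize once over the whole window, then backprop
    (running_loss / total_tokens).backward()
    return running_loss.detach() / total_tokens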
@@ -1068,6 +1084,26 @@ def cleanup(self) -> None:
self._metric_logger.close()
destroy_process_group()

# from torch.utils._python_dispatch import TorchDispatchMode
# import torch.utils._pytree as pytree
# from torch._higher_order_ops.flex_attention import flex_attention
#
#
#
# class Mode(TorchDispatchMode):
# def __torch_dispatch__(self, func, types, args=(), kwargs=None):
# r = torch.distributed.get_rank()
# print(f"XXX RANK[{r}] MODE._torch_dispatch_ {func} {types}")
# for a in pytree.tree_leaves(args):
# if issubclass(type(a), torch.Tensor):
# print(f"XXX RANK[{r}] {a.dtype} {a.shape}")
# else:
# print(f"XXX RANK[{r}] {a}")
# return func(*args, **kwargs)
#
# def flex_attention_mode_call(mode, *args, **kwargs):
# return flex_attention(*args, **kwargs)
# flex_attention.py_impl(Mode)(flex_attention_mode_call)

@config.parse
def recipe_main(cfg: DictConfig) -> None:
@@ -1081,6 +1117,7 @@ def recipe_main(cfg: DictConfig) -> None:
config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
recipe = FullFinetuneRecipeDistributed(cfg=cfg)
recipe.setup(cfg=cfg)
# with Mode():
recipe.train()
recipe.cleanup()

12 changes: 10 additions & 2 deletions torchtune/models/llama4/_component_builders.py
@@ -38,6 +38,7 @@
TokenChoiceTopKRouter,
)
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
from torchtune.utils._device import has_cuda_capability

"""
Component builders for the Llama4 model.
@@ -180,6 +181,7 @@ def llama4_decoder(
num_experts: int = 16,
experts_per_token: int = 1,
use_shared_expert: bool = True,
use_grouped_mm: bool = True,
use_qk_norm: bool = True,
moe_every_n_layers: Optional[int] = None,
mlp_hidden_dim: Optional[int] = None,
@@ -244,6 +246,11 @@
raise ValueError(
"Must pass local_chunk_size when enabling local chunked attention"
)
if use_grouped_mm and not has_cuda_capability(9, 0):
torchtune.utils.get_logger("WARNING").warning(
"Grouped mm is only supported on SM90 or later; falling back to the for-loop implementation",
)
use_grouped_mm = False
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads

@@ -263,7 +270,6 @@
)
layers = []
for i in range(num_layers):

mask_mod = None
if skip_rope_interval is not None and (i + 1) % skip_rope_interval != 0:
mask_mod = partial(
@@ -300,6 +306,7 @@
num_experts=num_experts,
experts_per_token=experts_per_token,
use_shared_expert=use_shared_expert,
use_grouped_mm=use_grouped_mm,
)
else:
mlp_layer = llama4_mlp(dim=embed_dim, hidden_dim=mlp_hidden_dim)
@@ -355,6 +362,7 @@ def llama4_moe(
num_experts: int = 8,
experts_per_token: int = 1,
use_shared_expert: bool = True,
use_grouped_mm: bool = True,
) -> MoE:
"""
Build the MoE layer associated with the Llama model.
@@ -631,6 +639,7 @@ def lora_llama4_decoder(
raise ValueError(
"Must pass local_chunk_size when enabling local chunked attention"
)

head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
if use_scaled_rope:
@@ -649,7 +658,6 @@
)
layers = []
for i in range(num_layers):

mask_mod = None
if skip_rope_interval is not None and (i + 1) % skip_rope_interval != 0:
mask_mod = partial(
2 changes: 1 addition & 1 deletion torchtune/modules/attention_utils.py
@@ -47,7 +47,7 @@ def compile_flex_attention():
# when compiled. To insulate it from the compiler, we wrap it with
# compiler.disable so that it can be used regardless of whether the model
# is compiled or not, and flex attention always remains compiled.
@torch.compiler.disable(recursive=False)
# @torch.compiler.disable(recursive=False)
def compile_friendly_flex_attention(
q: torch.Tensor,
k: torch.Tensor,
80 changes: 60 additions & 20 deletions torchtune/modules/moe/experts.py
@@ -13,6 +13,11 @@
from torchtune.modules.peft import AdapterModule


@torch._dynamo.allow_in_graph
def _grouped_mm(x, w, offs):
return torch._grouped_mm(x, w, offs=offs)


class GroupedExperts(nn.Module):
"""This class implements the grouped experts layer used in Mixture of Experts. Each expert
is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.
@@ -31,6 +36,7 @@ def __init__(
hidden_dim: int,
num_experts: int = 1,
activation: Callable = F.silu,
use_grouped_mm: bool = False,
):
super().__init__()
self.dim = dim
@@ -39,6 +45,8 @@ def __init__(
self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
self.up_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
self.act_fn = activation
self.rank = torch.distributed.get_rank()
self.use_grouped_mm = use_grouped_mm

def reset_parameters(self) -> None:
# Default initialization used by torch.nn.Linear
@@ -50,6 +58,7 @@ def reset_parameters(self) -> None:
# TODO: force no inference mode as a hack to get around
# "Cannot set version_counter for inference tensor"
@torch.inference_mode(mode=False)
@torch._dynamo.disable(recursive=False)
def forward(
self,
x: torch.Tensor,
@@ -64,28 +73,59 @@ def forward(
Returns:
torch.Tensor: tensor with shape ``(bsz * seq_len * experts_per_token, dim)``
"""

# a tuple of tensors indexed by experts
# each with shape (tokens_per_expert(varying), dim)
x = torch.split(
x,
split_size_or_sections=num_tokens_per_expert.tolist(),
dim=0,
)
out_experts_splits = []
for expert_idx, x_expert in enumerate(x):
w1, w2, w3 = (
self.gate_proj[expert_idx],
self.down_proj[expert_idx],
self.up_proj[expert_idx],
if not self.use_grouped_mm:
# a tuple of tensors indexed by experts
# each with shape (tokens_per_expert(varying), dim)
num_tokens_per_expert_list = num_tokens_per_expert.tolist()
if torch.compiler.is_compiling():
for n in num_tokens_per_expert_list:
torch._check_is_size(n)
x = torch.split(
x,
split_size_or_sections=num_tokens_per_expert_list,
dim=0,
)
h = self.act_fn(torch.matmul(x_expert, w1))
h = h * torch.matmul(x_expert, w3)
h = torch.matmul(h, w2)
# h shape (tokens_per_expert(varying), dim)
out_experts_splits.append(h)
out = torch.cat(out_experts_splits, dim=0)
out_experts_splits = []
for expert_idx, x_expert in enumerate(x):
w1, w2, w3 = (
self.gate_proj[expert_idx],
self.down_proj[expert_idx],
self.up_proj[expert_idx],
)
h = self.act_fn(torch.matmul(x_expert, w1))
h = h * torch.matmul(x_expert, w3)
h = torch.matmul(h, w2)
# h shape (tokens_per_expert(varying), dim)
out_experts_splits.append(h)
out = torch.cat(out_experts_splits, dim=0)

return out

# grouped mm implementation
if num_tokens_per_expert is not None:
# https://github.com/pytorch/pytorch/pull/150374
# NOTE: torch._grouped_mm requires bf16 dtypes
# and shapes to be a multiple of 8
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
# grouped mm between a 2D tensor and a 3D tensor
assert x.dim() == 2
else:
offsets = None
# fall back to regular bmm between 3D tensors
assert x.dim() == 3

w1, w2, w3 = (
self.gate_proj,
self.down_proj,
self.up_proj,
)
assert (
x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16
), "torch._grouped_mm only supports bf16 dtypes"
h = F.silu(_grouped_mm(x, w1, offs=offsets))
h = h * _grouped_mm(x, w3, offs=offsets)
out = _grouped_mm(h, w2, offs=offsets)
# rows past offsets[-1] are padding; zero them (skip in the bmm path, where offsets is None)
if offsets is not None:
    out[offsets[-1] :].zero_()
return out
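For intuition, the grouped path treats x as a stack of per-expert row blocks: offsets are the cumulative token counts, and each expert's block is multiplied by that expert's weight matrix. A rough pure-PyTorch equivalent of a single _grouped_mm call on a 2D input (an illustrative sketch of the assumed semantics, not the fused SM90 kernel):

import torch

def grouped_mm_reference(x: torch.Tensor, w: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # x: (total_tokens, dim), with each expert's tokens stored contiguously
    # w: (num_experts, dim, out_dim); offs: (num_experts,) cumulative token counts
    out = x.new_empty(x.shape[0], w.shape[-1])
    start = 0
    for expert_idx in range(w.shape[0]):
        end = int(offs[expert_idx])
        out[start:end] = x[start:end] @ w[expert_idx]
        start = end
    return out

Rows past offs[-1] (padding added so shapes stay a multiple of 8) are left uninitialized here, which is why the forward above zeroes them out explicitly.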

