From 5c971f4b132a230ca096a2c96d15610afe55ee9f Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 8 Sep 2025 09:06:19 +0000 Subject: [PATCH 1/2] Create _testing/models sub-folder and move llama3 there I'm going to be starting adding tess specific to llama3, so having the model easily accessible without copying it would be good. Plus, with the pipeline example in https://github.com/meta-pytorch/autoparallel/pull/121 we already have the implementation of llama3 which is duplicated, so let's keep it in a sub-folder --- autoparallel/_testing/models/llama3.py | 546 +++++++++++++++++++++++++ examples/example_llama3.py | 545 +----------------------- 2 files changed, 547 insertions(+), 544 deletions(-) create mode 100644 autoparallel/_testing/models/llama3.py diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py new file mode 100644 index 00000000..aa196e11 --- /dev/null +++ b/autoparallel/_testing/models/llama3.py @@ -0,0 +1,546 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import ClassVar + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.attention import SDPBackend, sdpa_kernel + + +def has_cuda_capability(major: int, minor: int) -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + major, + minor, + ) + + +class ScaledDotProductAttention(torch.nn.Module): + backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self, attn_mask_type: str) -> None: + super().__init__() + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + + ScaledDotProductAttention._init_backend() + + @classmethod + def _init_backend(cls) -> None: + if cls.backends: + return + + # Add CuDNN on B200 w/ highest priority + cls.backends = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + if has_cuda_capability(10, 0): + cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + assert self.backends, "SDPA Backends should not be empty." + with sdpa_kernel(self.backends, set_priority=True): + return F.scaled_dot_product_attention(q, k, v, is_causal=True) + + +def build_attention( + use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None +): + if use_flex_attn: + raise NotImplementedError() + # return FlexAttention(attn_mask_type, fixed_block_size) + else: + if fixed_block_size is not None: + raise ValueError( + "TorchTitan with SDPA currently does not support fixed_block_size." + ) + if attn_mask_type != "causal": + raise ValueError( + "TorchTitan with SDPA currently only supports causal mask." + ) + return ScaledDotProductAttention(attn_mask_type) + + +@dataclass +class TransformerModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int = 64000 # -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 0 + + def update_from_config(self, job_config, tokenizer) -> None: + self.vocab_size = tokenizer.n_words + self.max_seq_len = job_config.training.seq_len + self.eos_id = tokenizer.eos_id + + if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with selective AC yet. " + "See https://github.com/pytorch/pytorch/issues/147879" + ) + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with CP yet. " + "We are still working on this." + ) + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + nparams = sum(p.numel() for p in model.parameters()) + nparams_embedding = sum( + sum(p.numel() for p in m.parameters()) + for m in model.children() + if isinstance(m, nn.Embedding) + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + + return nparams, num_flops_per_token + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # TODO: uncomment + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + output = self.sdpa(xq, xk, xv) + + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float | None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + return + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + buffer_device = buffer_device or self.freqs_cis.device # type: ignore + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() # type: ignore + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + input_batch (torch.Tensor): The input batch read from the dataloader. + This will always be the input batch regardless of the pipeline stage. + This field is required for non-first PP stages to perform document + masking attention (to analyze the boundary of the document). + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output diff --git a/examples/example_llama3.py b/examples/example_llama3.py index 97dcc4b5..eeff9ffb 100644 --- a/examples/example_llama3.py +++ b/examples/example_llama3.py @@ -4,558 +4,15 @@ # LICENSE file in the root directory of this source tree. import time -from dataclasses import dataclass -from typing import ClassVar import torch -import torch.nn.functional as F -from torch import nn from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Partial, Replicate, Shard -from torch.nn.attention import SDPBackend, sdpa_kernel from torch.testing._internal.distributed.fake_pg import FakeStore +from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel - -def has_cuda_capability(major: int, minor: int) -> bool: - return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( - major, - minor, - ) - - -class ScaledDotProductAttention(torch.nn.Module): - backends: ClassVar[list[SDPBackend]] = [] - - def __init__(self, attn_mask_type: str) -> None: - super().__init__() - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - - ScaledDotProductAttention._init_backend() - - @classmethod - def _init_backend(cls) -> None: - if cls.backends: - return - - # Add CuDNN on B200 w/ highest priority - cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] - if has_cuda_capability(10, 0): - cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) - - def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> torch.Tensor: - assert self.backends, "SDPA Backends should not be empty." - with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True) - - -def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None -): - if use_flex_attn: - raise NotImplementedError() - # return FlexAttention(attn_mask_type, fixed_block_size) - else: - if fixed_block_size is not None: - raise ValueError( - "TorchTitan with SDPA currently does not support fixed_block_size." - ) - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - return ScaledDotProductAttention(attn_mask_type) - - -@dataclass -class TransformerModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: int | None = None - vocab_size: int = 64000 # -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: float | None = None - norm_eps: float = 1e-5 - rope_theta: float = 10000 - - max_seq_len: int = 2048 - # If `True`, then each transformer block init uses its layer ID, and if - # `False`, each uses the total number of transformer blocks - depth_init: bool = True - - use_flex_attn: bool = False - attn_mask_type: str = "causal" - eos_id: int = 0 - - def update_from_config(self, job_config, tokenizer) -> None: - self.vocab_size = tokenizer.n_words - self.max_seq_len = job_config.training.seq_len - self.eos_id = tokenizer.eos_id - - if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with selective AC yet. " - "See https://github.com/pytorch/pytorch/issues/147879" - ) - - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with CP yet. " - "We are still working on this." - ) - - def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: - nparams = sum(p.numel() for p in model.parameters()) - nparams_embedding = sum( - sum(p.numel() for p in m.parameters()) - for m in model.children() - if isinstance(m, nn.Embedding) - ) - - l, h, q, t = ( - self.n_layers, - self.n_heads, - self.dim // self.n_heads, - seq_len, - ) - # Reasoning behind the factor of 12 for the self-attention part of the formula: - # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) - # 2. the flash attention does 1 more matmul recomputation in the backward - # but recomputation should not be counted in calculating MFU (+0) - # 3. each matmul performs 1 multiplication and 1 addition (*2) - # 4. we follow the convention and do not account for sparsity in causal attention - num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t - - return nparams, num_flops_per_token - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. - """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), - and the first seqlen elements will be sliced, but dim must match x. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - - Returns: - torch.Tensor: Reshaped frequency tensor. - """ - ndim = x.ndim - assert ndim > 1 - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. - - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided - frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are - returned as real tensors. - - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. - xk (torch.Tensor): Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - torch.unsqueeze(x, dim=3) - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class Attention(nn.Module): - """ - Multi-head attention module. - - Args: - model_args (TransformerModelArgs): Model configuration arguments. - - Attributes: - n_kv_heads (int): Number of key and value heads. - n_heads (int): Number of query heads. - n_rep (int): Number of repetitions for local heads. - head_dim (int): Dimension size of each attention head. - wq (Linear): Linear transformation for queries. - wk (Linear): Linear transformation for keys. - wv (Linear): Linear transformation for values. - wo (Linear): Linear transformation for output. - - """ - - def __init__(self, model_args: TransformerModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads - ) - self.n_rep = self.n_heads // self.n_kv_heads - self.head_dim = model_args.dim // model_args.n_heads - - self.wq = nn.Linear( - model_args.dim, model_args.n_heads * self.head_dim, bias=False - ) - self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear( - model_args.n_heads * self.head_dim, model_args.dim, bias=False - ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) - - def init_weights(self, init_std: float): - for linear in (self.wq, self.wk, self.wv): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ): - """ - Forward pass of the attention module. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed frequency tensor. - - Returns: - torch.Tensor: Output tensor after attention. - - """ - - bs, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual - # local heads from sizes of xq, xk, and xv as TP may have sharded them - # after the above linear ops. - xq = xq.view(bs, seqlen, -1, self.head_dim) - xk = xk.view(bs, seqlen, -1, self.head_dim) - xv = xv.view(bs, seqlen, -1, self.head_dim) - - # TODO: uncomment - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - output = self.sdpa(xq, xk, xv) - - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - """ - FeedForward module - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. - - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: float | None, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - -class TransformerBlock(nn.Module): - """ - TransformerBlock Module - - Args: - layer_id (int): Identifier for the layer. - model_args (TransformerModelArgs): Model configuration arguments. - - Attributes: - n_heads (int): Number of attention heads. - dim (int): Dimension size of the model. - head_dim (int): Dimension size of each attention head. - attention (Attention): Attention module. - feed_forward (FeedForward): FeedForward module. - layer_id (int): Identifier for the layer. - attention_norm (RMSNorm): Layer normalization for attention output. - ffn_norm (RMSNorm): Layer normalization for feedforward output. - - """ - - def __init__(self, layer_id: int, model_args: TransformerModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.dim = model_args.dim - - self.attention = Attention(model_args) - self.feed_forward = FeedForward( - dim=model_args.dim, - hidden_dim=4 * model_args.dim, - multiple_of=model_args.multiple_of, - ffn_dim_multiplier=model_args.ffn_dim_multiplier, - ) - self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - - if model_args.depth_init: - self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 - else: - self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - ): - """ - Perform a forward pass through the TransformerBlock. - - Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - - Returns: - torch.Tensor: Output tensor after applying attention and feedforward layers. - - """ - h = x + self.attention(self.attention_norm(x), freqs_cis) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - def init_weights(self): - return - for norm in (self.attention_norm, self.ffn_norm): - norm.reset_parameters() - self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) - - -class Transformer(nn.Module): - """ - Transformer Module - - Args: - model_args (TransformerModelArgs): Model configuration arguments. - - Attributes: - model_args (TransformerModelArgs): Model configuration arguments. - vocab_size (int): Vocabulary size. - n_layers (int): Number of layers in the model. - tok_embeddings (ParallelEmbedding): Token embeddings. - layers (torch.nn.ModuleList): List of Transformer blocks. - norm (RMSNorm): Layer normalization for the model output. - output (ColumnParallelLinear): Linear layer for final output. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - - """ - - def __init__(self, model_args: TransformerModelArgs): - super().__init__() - self.model_args = model_args - self.vocab_size = model_args.vocab_size - self.n_layers = model_args.n_layers - self.eos_id = model_args.eos_id - - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) - - self.layers = torch.nn.ModuleDict() - for layer_id in range(model_args.n_layers): - self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) - - def init_weights( - self, - buffer_device: torch.device | None = None, - ): - """ - [Note: On ``init_weights`` vs. ``reset_parameters``] - Modules may define ``reset_parameters`` to initialize parameter values. - ``reset_parameters`` is meant to only initialize directly owned - parameters/buffers, not those of their child modules, and it can be - used to give the initial values for these tensors. - Separately, users may want custom initialization for their modules, - different from that in ``reset_parameters``. For this, we define - ``init_weights``. We only call it in the constructor of this - ``Transformer`` root module to avoid reinitializing tensors. - """ - buffer_device = buffer_device or self.freqs_cis.device # type: ignore - with torch.device(buffer_device): - self.freqs_cis = self._precompute_freqs_cis() - if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) - for layer in self.layers.values(): - if layer is not None: - layer.init_weights() - if self.norm is not None: - self.norm.reset_parameters() - final_out_std = self.model_args.dim**-0.5 - cutoff_factor = 3 - if self.output is not None: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) - - def _precompute_freqs_cis(self) -> torch.Tensor: - return precompute_freqs_cis( - self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR - self.model_args.max_seq_len, - self.model_args.rope_theta, - ) - - def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): - """ - Perform a forward pass through the Transformer model. - - Args: - tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. - If pipeline parallelism is enabled, this will be the input token indices - for the ranks on the first pipeline stage. This will be the activation of the - previous pipeline stage if the current rank is not on the first stage. - input_batch (torch.Tensor): The input batch read from the dataloader. - This will always be the input batch regardless of the pipeline stage. - This field is required for non-first PP stages to perform document - masking attention (to analyze the boundary of the document). - - Returns: - torch.Tensor: Output logits after applying the Transformer model. - - """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - for layer in self.layers.values(): - h = layer(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - output = self.output(h) if self.output else h - return output - - -# ============================================================== -# AutoParallel code starts here -# ============================================================== - world_size = 64 fake_store = FakeStore() From f0317a572fd94ea59ad353c4375c050793ad33a9 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 8 Sep 2025 09:40:46 +0000 Subject: [PATCH 2/2] py39 typing --- autoparallel/_testing/models/llama3.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py index aa196e11..9d349e1a 100644 --- a/autoparallel/_testing/models/llama3.py +++ b/autoparallel/_testing/models/llama3.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import ClassVar +from typing import ClassVar, Optional import torch import torch.nn.functional as F @@ -54,7 +54,7 @@ def forward( def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None + use_flex_attn: bool, attn_mask_type: str, fixed_block_size: Optional[int] = None ): if use_flex_attn: raise NotImplementedError() @@ -76,10 +76,10 @@ class TransformerModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 - n_kv_heads: int | None = None + n_kv_heads: Optional[int] = None vocab_size: int = 64000 # -1 # defined later by tokenizer multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: float | None = None + ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 rope_theta: float = 10000 @@ -338,7 +338,7 @@ def __init__( dim: int, hidden_dim: int, multiple_of: int, - ffn_dim_multiplier: float | None, + ffn_dim_multiplier: Optional[float], ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) @@ -473,7 +473,7 @@ def __init__(self, model_args: TransformerModelArgs): def init_weights( self, - buffer_device: torch.device | None = None, + buffer_device: Optional[torch.device] = None, ): """ [Note: On ``init_weights`` vs. ``reset_parameters``] @@ -517,7 +517,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_theta, ) - def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): + def forward(self, tokens: torch.Tensor, input_batch: Optional[torch.Tensor] = None): """ Perform a forward pass through the Transformer model.