diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py index dc75004af0..4b20b1316c 100644 --- a/llmfoundry/layers_registry.py +++ b/llmfoundry/layers_registry.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Any, Callable import torch @@ -176,6 +176,20 @@ description=_attention_implementations_description, ) +_flex_attention_mods_description = ( + """The flex_attention_mods registry is used to register classes that implement flex attention mods. + + One example is 'CausalMaskMod'. See flex_attn_mods.py for examples. + """ +) +flex_attention_mods = create_registry( + 'llmfoundry', + 'flex_attention_mods', + generic_type=type[Any], + entry_points=True, + description=_flex_attention_mods_description, +) + _param_init_fns_description = ( """The param_init_fns registry is used to register functions that initialize parameters. @@ -231,5 +245,6 @@ 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', + 'flex_attention_mods', 'fcs', ] diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index a76c135f7e..094b661212 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -17,9 +17,16 @@ from llmfoundry.layers_registry import ( attention_classes, attention_implementations, + flex_attention_mods, +) +from llmfoundry.models.layers.flex_attn_utils import ( + FLEX_ATTN_COMPILE, + generate_block_mask, + generate_score_mod, ) from llmfoundry.models.layers.layer_builders import build_fc, build_norm from llmfoundry.models.utils.config_defaults import fc_type_defaults +from llmfoundry.utils.warnings import experimental_function __all__ = [ 'scaled_multihead_dot_product_attention', @@ -167,11 +174,6 @@ def scaled_multihead_dot_product_attention( attn_weight = q.matmul(k) * softmax_scale - if attn_logit_softcapping is not None: - attn_weight = attn_logit_softcapping * torch.tanh( - attn_weight / attn_logit_softcapping, - ) - if attn_bias is not None: # clamp to 0 necessary for torch 2.0 compile() _s_q = max(0, attn_bias.size(2) - s_q) @@ -185,6 +187,11 @@ def scaled_multihead_dot_product_attention( ) attn_weight = attn_weight + attn_bias + if attn_logit_softcapping is not None: + attn_weight = attn_logit_softcapping * torch.tanh( + attn_weight / attn_logit_softcapping, + ) + min_val = torch.finfo(q.dtype).min if key_padding_mask is not None: @@ -445,6 +452,166 @@ def flash_attn_fn( return output, None, past_key_value +@experimental_function('Flex Attention') +def flex_attn_fn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + n_heads: int, + kv_n_heads: int, + compiled_flex_attention: Any, + compiled_create_block_mask: Any, + sequence_id_info: Optional[dict[str, torch.Tensor]], + flex_attn_compile: bool, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + softmax_scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + dropout_p: float = 0.0, + training: bool = False, + needs_weights: bool = False, + should_repeat_kv_for_gqa: Optional[bool] = True, + sliding_window_size: int = -1, + alibi_slopes: Optional[torch.Tensor] = None, + attn_logit_softcapping: Optional[float] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, + torch.Tensor]]]: + del training, 
should_repeat_kv_for_gqa + if attn_bias is not None: + raise ValueError('attn_bias should be None for flex attn.') + if key_padding_mask is not None: + raise ValueError( + 'key_padding_mask should be None for flex attn. Instead, any padding information should be sent through sequence_id_info.', + ) + if dropout_p > 0.0: + raise NotImplementedError(f'dropout not implemented for flex attn.') + if needs_weights: + raise NotImplementedError( + f'needs_weights not implemented for flex attn.', + ) + + check_valid_inputs(query, key, value) + assert key.shape[1] == value.shape[1] + assert query.shape[1] == key.shape[1] + query_offset = torch.tensor(0, device=query.device) + if past_key_value is not None: + if len(past_key_value) != 0: + assert past_key_value[0].shape[1] == past_key_value[1].shape[1] + query_offset += past_key_value[0].shape[1] + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + + past_key_value = (key, value) + + enable_gqa = (n_heads != kv_n_heads) + query = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + key = rearrange(key, 'b s (h d) -> b h s d', h=kv_n_heads) + value = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) + + def _check_mod_list(mod_list: list[dict[str, Any]], name: str): + for mod in mod_list: + if mod['mod_name'] == name: + raise ValueError( + f'{name} mod should not be defined through flex attention config.', + ) + + flex_attn_mod_list = copy.deepcopy( + flex_attn_mod_list, + ) if flex_attn_mod_list is not None else [] + if is_causal: + _check_mod_list(flex_attn_mod_list, 'causal_mask') + flex_attn_mod_list.append({'mod_name': 'causal_mask', 'mod_kwargs': {}}) + if sliding_window_size != -1: + flex_attn_mod_list.append({ + 'mod_name': 'sliding_window_mask', + 'mod_kwargs': { + 'sliding_window_size': + torch.tensor(sliding_window_size, device=query.device), + }, + }) + if sequence_id_info is not None and 'sequence_id' in sequence_id_info and sequence_id_info[ + 'sequence_id'] is not None: + _check_mod_list(flex_attn_mod_list, 'sequence_id_mask') + flex_attn_mod_list.append({ + 'mod_name': 'sequence_id_mask', + 'mod_kwargs': {}, + }) + + if sequence_id_info is not None and 'attention_mask' in sequence_id_info and sequence_id_info[ + 'attention_mask'] is not None: + _check_mod_list(flex_attn_mod_list, 'attention_mask') + flex_attn_mod_list.append({ + 'mod_name': 'attention_mask', + 'mod_kwargs': {}, + }) + + if alibi_slopes is not None: + _check_mod_list(flex_attn_mod_list, 'alibi_score_mod') + flex_attn_mod_list.append({ + 'mod_name': 'alibi_score_mod', + 'mod_kwargs': { + 'alibi_slopes': alibi_slopes, + }, + }) + if attn_logit_softcapping is not None: + if int(attn_logit_softcapping) != attn_logit_softcapping: + raise ValueError( + f'FlexAttention does not support attn_logit_softcapping with float softcap temperature. Got {attn_logit_softcapping=}. 
Please consider rounding it to the closest integer.', ) + attn_logit_softcapping = int(attn_logit_softcapping) + _check_mod_list(flex_attn_mod_list, 'softcap_score_mod') + flex_attn_mod_list.append({ + 'mod_name': 'softcap_score_mod', + 'mod_kwargs': { + 'attn_logit_softcapping': + torch.tensor(attn_logit_softcapping, device=query.device), + }, + }) + + flex_attn_mod_list = [ + flex_attention_mods.get(mod['mod_name'])(**mod['mod_kwargs']) + for mod in flex_attn_mod_list + ] + block_mask_list = [ + mod for mod in flex_attn_mod_list + if mod.mod_type == 'mask' # type: ignore + ] + score_mod_list = [ + mod for mod in flex_attn_mod_list + if mod.mod_type == 'score' # type: ignore + ] + + block_mask = generate_block_mask( + Q_LEN=query.shape[2], + KV_LEN=key.shape[2], + B=query.shape[0], + block_mask_list=block_mask_list, # type: ignore + compiled_create_block_mask=compiled_create_block_mask, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + flex_attn_compile=flex_attn_compile, + ) + score_mod = generate_score_mod( + score_mod_list=score_mod_list, # type: ignore + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ) + + output = compiled_flex_attention( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask, + scale=softmax_scale, + enable_gqa=enable_gqa, + ) + output = rearrange(output, 'b h s d -> b s (h d)') + return output, None, past_key_value + + @attention_classes.register_class('grouped_query_attention') class GroupedQueryAttention(nn.Module): """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA). @@ -479,6 +646,8 @@ def __init__( attn_logit_softcapping: Optional[float] = None, attn_temperature_tuning: Optional[dict] = None, kv_dim: Optional[int] = None, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + flex_attn_compile: bool = FLEX_ATTN_COMPILE, nope: bool = False, ): super().__init__() @@ -607,6 +776,10 @@ def __init__( ) self.out_proj._is_residual = True + if self.attn_impl == 'flex': + self.flex_attn_mod_list = flex_attn_mod_list + self.flex_attn_compile = flex_attn_compile + def forward( self, x: torch.Tensor, @@ -621,6 +794,7 @@ def forward( prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, key_value_states: Optional[torch.Tensor] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, pos_id_within_seq: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: @@ -647,6 +821,7 @@ def forward( attention_mask, alibi_slopes, flash_attn_padding_info, + flex_attn_kwargs, ) context, attn_weights, past_key_value = self.attn_fn( @@ -857,6 +1032,7 @@ def get_implementation_specific_args( attention_mask: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: """Returns attention implementation specific args. Args: attention_mask (Optional[torch.Tensor]): The attention mask. alibi_slopes (Optional[torch.Tensor]): The alibi slopes. flash_attn_padding_info (Optional[dict[str, torch.Tensor]]): The padding information, only required for flash attention. + flex_attn_kwargs (Optional[dict[str, Any]]): The extra flex attention kwargs sent from the model, including sequence id info and the compiled flex attention functions. Returns: extra_attn_kwargs (dict[str, Any]): Implementation specific args.
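For reference, a minimal sketch of how a caller can assemble flex_attn_kwargs, mirroring the modeling_mpt.py changes later in this diff; build_flex_attn_kwargs is an illustrative helper name, not part of the patch.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def build_flex_attn_kwargs(sequence_id, pos_id_within_seq, flex_attn_compile: bool):
    # Compile once (e.g. at model construction time) and reuse across blocks and forward passes.
    compiled_flex_attention = torch.compile(flex_attention) if flex_attn_compile else flex_attention
    compiled_create_block_mask = torch.compile(create_block_mask) if flex_attn_compile else create_block_mask
    return {
        'compiled_flex_attention': compiled_flex_attention,
        'compiled_create_block_mask': compiled_create_block_mask,
        'sequence_id_info': {
            'sequence_id': sequence_id,  # per-token sequence ids, or None
            'pos_in_seq': pos_id_within_seq,  # position within each packed sequence, or None
        },
    }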
@@ -877,6 +1054,24 @@ def get_implementation_specific_args( 'flash_attn_padding_info': flash_attn_padding_info, 'key_padding_mask': None, } + elif self.attn_impl == 'flex': + if flex_attn_kwargs is None: + raise ValueError( + 'flex_attn_kwargs must be provided for flex attention.', + ) + if 'sequence_id_info' not in flex_attn_kwargs: + raise ValueError( + 'sequence_id_info must be provided in flex_attn_kwargs.', + ) + flex_attn_kwargs['sequence_id_info']['attention_mask' + ] = attention_mask + extra_attn_kwargs = { + 'alibi_slopes': alibi_slopes, + 'key_padding_mask': None, + 'flex_attn_mod_list': self.flex_attn_mod_list, + 'flex_attn_compile': self.flex_attn_compile, + **flex_attn_kwargs, + } else: extra_attn_kwargs = {'key_padding_mask': attention_mask} return extra_attn_kwargs @@ -910,6 +1105,8 @@ def __init__( attn_logit_softcapping: Optional[float] = None, attn_temperature_tuning: Optional[dict] = None, kv_dim: Optional[int] = None, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + flex_attn_compile: bool = FLEX_ATTN_COMPILE, nope: bool = False, ): super().__init__( @@ -933,6 +1130,8 @@ def __init__( attn_logit_softcapping=attn_logit_softcapping, attn_temperature_tuning=attn_temperature_tuning, kv_dim=kv_dim, + flex_attn_mod_list=flex_attn_mod_list, + flex_attn_compile=flex_attn_compile, nope=nope, ) @@ -965,6 +1164,8 @@ def __init__( attn_logit_softcapping: Optional[float] = None, attn_temperature_tuning: Optional[dict] = None, kv_dim: Optional[int] = None, + flex_attn_mod_list: Optional[list[dict[str, Any]]] = None, + flex_attn_compile: bool = FLEX_ATTN_COMPILE, nope: bool = False, ): super().__init__( @@ -988,6 +1189,8 @@ def __init__( attn_logit_softcapping=attn_logit_softcapping, attn_temperature_tuning=attn_temperature_tuning, kv_dim=kv_dim, + flex_attn_mod_list=flex_attn_mod_list, + flex_attn_compile=flex_attn_compile, nope=nope, ) @@ -1000,7 +1203,7 @@ def attn_bias_shape( causal: bool, use_sequence_id: bool, ) -> Optional[tuple[int, int, int, int]]: - if attn_impl == 'flash': + if attn_impl == 'flash' or attn_impl == 'flex': return None elif attn_impl == 'torch': if alibi: @@ -1096,3 +1299,4 @@ def build_alibi_bias( 'torch', func=scaled_multihead_dot_product_attention, ) +attention_implementations.register('flex', func=flex_attn_fn) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 427f0548e1..6fc03c8a68 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -165,6 +165,7 @@ def forward( prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, key_value_states: Optional[torch.Tensor] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, pos_id_within_seq: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: @@ -187,6 +188,7 @@ def forward( output_attentions=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + flex_attn_kwargs=flex_attn_kwargs, **extra_kwargs, ) else: @@ -201,6 +203,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + flex_attn_kwargs=flex_attn_kwargs, **extra_kwargs, ) x = x + self.resid_attn_dropout(b) @@ -338,6 +341,7 @@ def forward( prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, key_value_states: Optional[torch.Tensor] = None, + flex_attn_kwargs: Optional[dict[str, Any]] = None, pos_id_within_seq: Optional[torch.Tensor] = 
None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: @@ -360,6 +364,7 @@ def forward( needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info, + flex_attn_kwargs=flex_attn_kwargs, **extra_kwargs, ) x = x + self.resid_attn_dropout(b) diff --git a/llmfoundry/models/layers/flex_attn_utils.py b/llmfoundry/models/layers/flex_attn_utils.py new file mode 100644 index 0000000000..a84ddf321e --- /dev/null +++ b/llmfoundry/models/layers/flex_attn_utils.py @@ -0,0 +1,332 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from abc import ABC +from functools import partial +from typing import Any, Optional + +import torch +from packaging import version +from torch.nn.attention.flex_attention import ( + _DEFAULT_SPARSE_BLOCK_SIZE, + _score_mod_signature, + and_masks, +) + +FLEX_ATTN_COMPILE = version.parse( + torch.__version__.split('.dev')[0], +) >= version.parse('2.6.0') + +from llmfoundry.layers_registry import flex_attention_mods + + +class FlexAttentionMod(ABC): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, query_offset, b, h, q_idx, kv_idx + raise NotImplementedError + + def _score_mod_fn( + self, + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, query_offset, score, b, h, q_idx, kv_idx + raise NotImplementedError + + def __init__(self, mod_type: str) -> None: + assert mod_type in ['mask', 'score'] + self.mod_type = mod_type + self.mod_fn = self._mask_mod_fn if mod_type == 'mask' else self._score_mod_fn + + +@flex_attention_mods.register('causal_mask') +class CausalMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, b, h + q_idx = q_idx + query_offset + return q_idx >= kv_idx + + def __init__(self) -> None: + super().__init__(mod_type='mask') + + +@flex_attention_mods.register('sliding_window_mask') +class SlidingWindowMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, b, h + q_idx = q_idx + query_offset + return torch.abs(q_idx - kv_idx) <= self.sliding_window_size + + def __init__(self, sliding_window_size: torch.Tensor) -> None: + super().__init__(mod_type='mask') + self.sliding_window_size = sliding_window_size + + +@flex_attention_mods.register('sequence_id_mask') +class SequenceIdMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del h + q_idx = q_idx + query_offset + if sequence_id_info is None: + raise ValueError( + 'sequence_id_info is required for SequenceIdMaskMod', + ) + sequence_id = 
sequence_id_info['sequence_id'] + # Check if the query and key belong to the same sequence. + return (sequence_id[b, q_idx] == sequence_id[b, kv_idx]) + + def __init__(self) -> None: + super().__init__(mod_type='mask') + + +@flex_attention_mods.register('attention_mask') +class AttentionMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del h, q_idx, query_offset + if sequence_id_info is None: + raise ValueError( + 'sequence_id_info is required for AttentionMaskMod', + ) + attention_mask = sequence_id_info['attention_mask'] + # Mask out keys that correspond to padding tokens. + return attention_mask[b, kv_idx] + + def __init__(self) -> None: + super().__init__(mod_type='mask') + + +@flex_attention_mods.register('local_global_mask') +class LocalGlobalMaskMod(FlexAttentionMod): + + def _mask_mod_fn( + self, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del h + q_idx = q_idx + query_offset + if sequence_id_info is None: + raise ValueError( + 'sequence_id_info is required for LocalGlobalMaskMod', + ) + pos_in_seq = sequence_id_info['pos_in_seq'] + # Attend to keys that fall within the global window at the start of the sequence or within the sliding window of the query. + + if pos_in_seq is not None: + global_window_mask = ( + pos_in_seq[b, kv_idx] <= self.global_window_size + ) + else: + global_window_mask = (kv_idx <= self.global_window_size) + sliding_window_mask = (q_idx - kv_idx <= self.sliding_window_size) + + return global_window_mask | sliding_window_mask + + def __init__( + self, + sliding_window_size: int, + global_window_size: int, + ) -> None: + super().__init__(mod_type='mask') + self.sliding_window_size = sliding_window_size + self.global_window_size = global_window_size + + +@flex_attention_mods.register('alibi_score_mod') +class AlibiScoreMod(FlexAttentionMod): + + def _score_mod_fn( + self, + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, b + q_idx = q_idx + query_offset + bias = -self.alibi_slopes[h] * torch.abs(q_idx - kv_idx) + return score + bias + + def __init__(self, alibi_slopes: torch.Tensor) -> None: + super().__init__(mod_type='score') + self.alibi_slopes = alibi_slopes + + +@flex_attention_mods.register('softcap_score_mod') +class SoftcapScoreMod(FlexAttentionMod): + + def _score_mod_fn( + self, + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + ) -> torch.Tensor: + del sequence_id_info, query_offset, b, h, q_idx, kv_idx + return self.attn_logit_softcapping * torch.tanh( + score / self.attn_logit_softcapping, + ) + + def __init__(self, attn_logit_softcapping: torch.Tensor) -> None: + super().__init__(mod_type='score') + self.attn_logit_softcapping = attn_logit_softcapping + + +def generate_block_mask( + Q_LEN: int, + KV_LEN: int, + B: int, + block_mask_list: Optional[list[FlexAttentionMod]], + compiled_create_block_mask: Any, +
query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], + flex_attn_compile: bool, +): + if block_mask_list is None: + return None + + block_mask_fn = None + for i, block_mask in enumerate(block_mask_list): + if i == 0: + block_mask_fn = partial( + block_mask.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ) + else: + block_mask_fn = and_masks( + block_mask_fn, # type: ignore + partial( + block_mask.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ), + ) + + extra_args = {} + if (Q_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != + 0) or (KV_LEN % _DEFAULT_SPARSE_BLOCK_SIZE != 0): + if flex_attn_compile: + warnings.warn( + f'Q_LEN and KV_LEN must be divisible by {_DEFAULT_SPARSE_BLOCK_SIZE}. The results might be incorrect.', + ) + else: + extra_args['BLOCK_SIZE'] = (Q_LEN, KV_LEN) + block_mask = compiled_create_block_mask( + block_mask_fn, + B=B, + H=None, # Setting this to None speeds up block mask generation, but this means the mask has to be the same across all heads. + Q_LEN=Q_LEN, + KV_LEN=KV_LEN, + **extra_args, + ) + + return block_mask + + +def generate_score_mod( + score_mod_list: Optional[list[FlexAttentionMod]], + query_offset: torch.Tensor, + sequence_id_info: Optional[dict[str, torch.Tensor]], +): + if score_mod_list is None: + return None + wrapped_score_mod = None + for i, score_mod in enumerate(score_mod_list): + if i == 0: + wrapped_score_mod = partial( + score_mod.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ) + else: + wrapped_score_mod = _wrap_score_mod_fns( + wrapped_score_mod, # type: ignore + partial( + score_mod.mod_fn, + query_offset=query_offset, + sequence_id_info=sequence_id_info, + ), + ) + + return wrapped_score_mod + + +def _wrap_score_mod_fns( + score_mod_fn_1: _score_mod_signature, + score_mod_fn_2: _score_mod_signature, +) -> _score_mod_signature: + + def wrapped_score_mod_fn( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + score = score_mod_fn_1(score, b, h, q_idx, kv_idx) + score = score_mod_fn_2(score, b, h, q_idx, kv_idx) + return score + + return wrapped_score_mod_fn diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index df3a485889..ef8a9b08dd 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -7,6 +7,8 @@ import warnings from typing import Any, Optional, Union +import torch +from packaging import version from transformers import PretrainedConfig from llmfoundry.layers_registry import ffns_with_megablocks @@ -298,10 +300,16 @@ def _validate_config(self) -> None: raise ValueError( "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1", ) - if self.attn_config['attn_impl'] not in ['torch', 'flash']: + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'flex']: raise ValueError( f"Unknown attn_impl={self.attn_config['attn_impl']}", ) + if self.attn_config['attn_impl'] == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.6.0'): + raise RuntimeError( + f'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + ) if self.attn_config['alibi'] and not check_alibi_support( self.attn_config['attn_impl'], ): @@ -309,7 +317,8 @@ def _validate_config(self) -> None: 'alibi only implemented with torch and flash (v2.4.2 or higher) attention.', ) if
self.attn_config['attn_uses_sequence_id'] and not ( - self.attn_config['attn_impl'] == 'torch' or ( + self.attn_config['attn_impl'] == 'torch' or + self.attn_config['attn_impl'] == 'flex' or ( self.attn_config['attn_impl'] == 'flash' and is_flash_v2_installed(v2_version='v2.1.2') ) @@ -434,6 +443,7 @@ def allowed_block_overrides(self): 'attn_config': { 'sliding_window_size': None, 'reuse_kv_layer_idx': None, + 'flex_attn_mod_list': None, 'attn_temperature_tuning': { 'floor_scale': None, 'attn_scale': None, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4242cdf72e..a2b1578886 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -26,9 +26,13 @@ from composer.models import HuggingFaceModel from composer.utils import dist from tabulate import tabulate +from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from llmfoundry.layers_registry import ffns_with_megablocks +from llmfoundry.layers_registry import ( + ffns_with_megablocks, +) from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above @@ -250,8 +254,9 @@ def _get_attn_mask_in_len_seq_one_hot( ): attention_mask_in_length = None sequence_id_one_hot = None - if (sequence_id - is not None) and attn_uses_sequence_id and (attn_impl == 'flash'): + if (sequence_id is not None) and attn_uses_sequence_id and ( + attn_impl == 'flash' or attn_impl == 'flex' + ): # Check if sequence has left padding. If yes, raise an error. if (attention_mask is not None ) and (attention_mask[:, 0].sum() != attention_mask.shape[0]): @@ -287,6 +292,7 @@ def _get_attn_mask_in_len_seq_one_hot( def gen_sequence_id_info( sequence_id: Union[None, torch.Tensor], + bsz: int, S: int, attn_uses_sequence_id: bool, attn_impl: str, @@ -307,7 +313,7 @@ def gen_sequence_id_info( pos_id_within_seq = pos_id_within_seq.sum(dim=-1) - 1 return attention_mask_in_length, pos_id_within_seq - return None, torch.arange(S, device=device)[None, :] + return None, torch.arange(S, device=device).repeat(bsz, 1) def gen_flash_attn_padding_info( @@ -458,6 +464,18 @@ def __init__(self, config: MPTConfig): self.mb_args = None self.shift_labels = True + flex_attn_compile = config.attn_config.get( + 'flex_attn_compile', + FLEX_ATTN_COMPILE, + ) + if self.attn_impl == 'flex': + self.compiled_flex_attention = torch.compile( + flex_attention, + ) if flex_attn_compile else flex_attention + self.compiled_create_block_mask = torch.compile( + create_block_mask, + ) if flex_attn_compile else create_block_mask + self.blocks = self.construct_blocks( config=config, ) @@ -763,7 +781,7 @@ def _attn_bias( self._attn_bias_initialized = True # flash will incorporate any attention_mask inside the attention module - if self.attn_impl == 'flash': + if self.attn_impl == 'flash' or self.attn_impl == 'flex': return self.attn_bias, attention_mask if self.attn_bias is not None: @@ -962,6 +980,7 @@ def forward( ) attention_mask_in_length, pos_id_within_seq = gen_sequence_id_info( sequence_id=sequence_id, + bsz=bsz, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, @@ -970,7 +989,9 @@ def forward( ) alibi_slopes = None # alibi_slopes will only be used by flash attention for ALiBi - if self.alibi and self.attn_impl == 'flash': + if self.alibi and ( + self.attn_impl == 
'flash' or self.attn_impl == 'flex' + ): alibi_slopes = gen_slopes( n_heads=self.config.n_heads, alibi_bias_max=self.alibi_bias_max, @@ -1020,6 +1041,19 @@ def forward( extra_kwargs = {} if prev_layer_key_value is not None: extra_kwargs['prev_layer_key_value'] = prev_layer_key_value + if self.attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'sequence_id_info': { + 'pos_in_seq': + pos_id_within_seq, + 'sequence_id': + sequence_id if self.attn_uses_sequence_id else None, + }, + 'compiled_flex_attention': + self.compiled_flex_attention, + 'compiled_create_block_mask': + self.compiled_create_block_mask, + } if pos_id_within_seq is not None: extra_kwargs['pos_id_within_seq'] = pos_id_within_seq x, attn_weights, present = block( diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 67b9811df8..d65f815f98 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -6,6 +6,8 @@ ffn_config_defaults: dict = { 'ffn_type': 'mptmlp', } +import torch +from packaging import version attn_config_defaults: dict = { 'attn_type': 'multihead_attention', @@ -34,6 +36,10 @@ 'type': 'no_scaling', 'factor': 1.0, }, + 'flex_attn_mod_list': [], + 'flex_attn_compile': + version.parse(torch.__version__.split('.dev')[0]) >= + version.parse('2.6.0'), 'attn_temperature_tuning': { 'floor_scale': 8192, 'attn_scale': 0.0, diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 850c4f3bbd..4cf542d34a 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -22,6 +22,7 @@ ffns, ffns_with_megablocks, ffns_with_norm, + flex_attention_mods, module_init_fns, norms, param_init_fns, @@ -432,6 +433,7 @@ 'ffns_with_megablocks', 'attention_classes', 'attention_implementations', + 'flex_attention_mods', 'fcs', 'icl_datasets', 'config_transforms', diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py index 871a320b56..7e4acc453b 100644 --- a/tests/models/layers/test_attention.py +++ b/tests/models/layers/test_attention.py @@ -6,18 +6,27 @@ import pytest import torch from composer.utils import reproducibility +from packaging import version +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from llmfoundry.models.layers.attention import ( apply_temperature_tuning, attention_implementations, scaled_multihead_dot_product_attention, ) +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import ( gen_flash_attn_padding_info, gen_sequence_id_info, ) +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if FLEX_ATTN_COMPILE: + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + @pytest.mark.parametrize( 'attn_name', @@ -174,14 +183,20 @@ def test_unfused_wqkv(attn_name: str, dim: int): @pytest.mark.gpu @pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) def test_sliding_window(sliding_window_size: int, attn_impl: str): # Test that sliding window attention works as expected. 
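As a quick sanity check on the torch version gate used for FLEX_ATTN_COMPILE above and for the pytest.skip guards in the tests below, stripping any '.dev' suffix lets nightly build strings compare cleanly (illustrative values only):

from packaging import version

# '2.6.0.dev20241112' -> '2.6.0', so a 2.6 nightly passes the >= 2.6.0 gate.
assert version.parse('2.6.0.dev20241112'.split('.dev')[0]) >= version.parse('2.6.0')
# A stable '2.5.1' has no '.dev' suffix, is unchanged by the split, and fails the gate.
assert version.parse('2.5.1'.split('.dev')[0]) < version.parse('2.6.0')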
+ if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) dtype = torch.bfloat16 device = 'cuda' d = 128 n_heads = 8 - seqlen_1 = 8 + seqlen_1 = 8 if attn_impl != 'flex' else 128 # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 query_1 = torch.randn(bsz, seqlen_1, @@ -209,6 +224,13 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str): 'should_repeat_kv_for_gqa': True, } + elif attn_impl == 'flex': + attn_extra_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, @@ -274,6 +296,140 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str): _assert_approx_equal(value_1.grad, value_2.grad) +@pytest.mark.gpu +@pytest.mark.parametrize('sliding_window_size', [1, 4]) +@pytest.mark.parametrize('global_window_size', [1, 4]) +@pytest.mark.parametrize('attn_impl', ['flex']) +def test_local_global_window( + sliding_window_size: int, + global_window_size: int, + attn_impl: str, +): + # Test that sliding window attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + + dtype = torch.bfloat16 + device = 'cuda' + d = 128 + n_heads = 8 + seqlen_1 = 128 + bsz = 1 + + query_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + value_1.requires_grad = True + + attn_extra_kwargs = {} + if attn_impl == 'flex': + attn_extra_kwargs = { + 'compiled_flex_attention': + compiled_flex_attention, + 'compiled_create_block_mask': + compiled_create_block_mask, + 'flex_attn_compile': + FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'pos_in_seq': None, + }, + 'flex_attn_mod_list': [{ + 'mod_name': 'local_global_mask', + 'mod_kwargs': { + 'sliding_window_size': sliding_window_size, + 'global_window_size': global_window_size, + }, + },], + } + + output_1, _, _ = attention_implementations.get(attn_impl)( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + sliding_window_size=-1, + **attn_extra_kwargs, + ) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + global_bias_2 = torch.where( + torch.arange(seqlen_1)[None, None, + None, :].to(dtype=dtype, device=device) + <= global_window_size, + torch.ones(1, 1, seqlen_1, + seqlen_1).to(dtype=torch.bool, device=device), + torch.zeros(1, 1, seqlen_1, + seqlen_1).to(dtype=torch.bool, device=device), + ) + + window_mask_2 = torch.tril( + 
torch.ones(seqlen_1, seqlen_1), + diagonal=-(sliding_window_size + 1), + ).to(dtype=torch.bool, device=device) + window_mask_2 = torch.where( + window_mask_2, + torch.zeros_like(window_mask_2), + torch.ones_like(window_mask_2), + ) + attn_bias_2 = global_bias_2 | window_mask_2 + attn_bias_2 = torch.where( + attn_bias_2, + torch.zeros_like(attn_bias_2, dtype=dtype), + torch.finfo(dtype).min, + ) + output_2, _, _ = scaled_multihead_dot_product_attention( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + ) + + output_2.sum().backward() + + _assert_approx_equal(output_1, output_2) + assert (query_2.grad is not None) and (query_1.grad is not None) + _assert_approx_equal(query_1.grad, query_2.grad) + assert (key_2.grad is not None) and (key_1.grad is not None) + _assert_approx_equal(key_1.grad, key_2.grad) + assert (value_2.grad is not None) and (value_1.grad is not None) + _assert_approx_equal(value_1.grad, value_2.grad) + + def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor): assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2) @@ -409,6 +565,7 @@ def test_gen_sequence_id_info(attn_uses_sequence_id: bool): _, pos_id_within_seq = gen_sequence_id_info( sequence_id=sequence_id, + bsz=n, S=s, attn_uses_sequence_id=attn_uses_sequence_id, attn_impl='flash', diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 666d93c9b4..abab741a94 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -6,31 +6,47 @@ import pytest import torch +from packaging import version +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from llmfoundry.models.layers.attention import ( + attention_implementations, attn_bias_shape, build_attn_bias, check_alibi_support, - flash_attn_fn, gen_slopes, is_flash_v2_installed, scaled_multihead_dot_product_attention, ) +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if FLEX_ATTN_COMPILE: + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + @pytest.mark.gpu @pytest.mark.skipif( not is_flash_v2_installed(), reason='GQA natively only supported by Flash Attention after v2.', ) +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) @pytest.mark.parametrize('kv_n_heads', [1, 4, 8]) -def test_gqa_kv_repetition(kv_n_heads: int): +def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int): # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) d = 128 n_heads = 8 - seqlen_1 = 6 + seqlen_1 = 6 if attn_impl != 'flex' else 128 # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). 
More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() @@ -41,7 +57,30 @@ def test_gqa_kv_repetition(kv_n_heads: int): kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -55,15 +94,7 @@ def test_gqa_kv_repetition(kv_n_heads: int): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, + **extra_attn_kwargs, ) output_1.sum().backward() @@ -74,8 +105,30 @@ def test_gqa_kv_repetition(kv_n_heads: int): key_2.requires_grad = True value_2 = value_1.detach().clone() value_2.requires_grad = True - - output_2, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_2.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + False, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + + output_2, _, _ = attention_implementations.get(attn_impl)( query=query_2, key=key_2, value=value_2, @@ -89,15 +142,7 @@ def test_gqa_kv_repetition(kv_n_heads: int): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_2.device, - None, - None, - ), - should_repeat_kv_for_gqa=False, + **extra_attn_kwargs, ) output_2.sum().backward() @@ -113,12 +158,19 @@ def test_gqa_kv_repetition(kv_n_heads: int): reason= 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) -def test_seq_id_masking_FA_v2(): +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) +def test_seq_id_masking_FA_v2(attn_impl: str): # Test that flash attention v2 with sequence id masking works correctly. - d = 128 + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + d = 128 # TODO: Compiled FlexAttention works for d=16 with seqlen=6, but not for d=128 with seqlen=6. For seqlen=128, all d's in [16, 32, 64, 128, 256] work. Probably because this is not yet fixed: https://pytorch.org/blog/flexattention/#limitations-and-future-work n_heads = 4 kv_n_heads = 4 - seqlen_1 = 6 + seqlen_1 = 128 bsz = 2 query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() @@ -134,9 +186,13 @@ def test_seq_id_masking_FA_v2(): (3, 5), (5, 6), ] # Each batch has 3 sequences of length 3, 2, and 1 respectively. 
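For intuition, the 'sequence_id_mask' mod defined in flex_attn_utils.py reduces to a same-sequence check over exactly this packed layout (a hedged illustration, not part of the test):

import torch

sequence_id = torch.tensor([0, 0, 0, 1, 1, 2])  # 3 packed sequences of lengths 3, 2, and 1
# Allowed (q_idx, kv_idx) pairs before the causal mask is ANDed in: same sequence only.
allowed = sequence_id[:, None] == sequence_id[None, :]
assert bool(allowed[3, 4]) and not bool(allowed[3, 2])  # q_idx 3 may see kv_idx 4, but not kv_idx 2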
- attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], - [3, 2, 1, 0, 0, - 0]]).to(torch.int64).cuda() + attention_mask_in_length_1 = torch.tensor([ + [3, 2, 1] + [0] * (seqlen_1 - 3), + [3, 2, 1] + [0] * (seqlen_1 - 3), + ]).to(torch.int64).cuda() + sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2] + [-1] * + (seqlen_1 - 6), [0, 0, 0, 1, 1, 2] + [-1] * + (seqlen_1 - 6)],).to(torch.int64).cuda() flash_attn_padding_info_1 = gen_flash_attn_padding_info( bsz, @@ -146,8 +202,19 @@ def test_seq_id_masking_FA_v2(): attention_mask_in_length_1, None, ) - - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs['flash_attn_padding_info'] = flash_attn_padding_info_1 + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'sequence_id': sequence_id, + }, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -161,7 +228,7 @@ def test_seq_id_masking_FA_v2(): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=flash_attn_padding_info_1, + **extra_attn_kwargs, ) output_1.sum().backward() @@ -182,8 +249,11 @@ def test_seq_id_masking_FA_v2(): None, None, ) - - output_2, _, _ = flash_attn_fn( + attn_impl = 'flash' + extra_attn_kwargs = { + 'flash_attn_padding_info': flash_attn_padding_info_2, + } + output_2, _, _ = attention_implementations.get(attn_impl)( query=query_2, key=key_2, value=value_2, @@ -197,11 +267,11 @@ def test_seq_id_masking_FA_v2(): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=flash_attn_padding_info_2, + **extra_attn_kwargs, ) output_2.sum().backward() - assert torch.allclose( + torch.testing.assert_close( output_1[:, seq_range[0]:seq_range[1], :], output_2, ) @@ -224,13 +294,24 @@ def test_seq_id_masking_FA_v2(): not check_alibi_support('flash'), reason='ALiBi only supported by Flash Attention after v2.4.2.', ) +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) @pytest.mark.parametrize('n_heads', [1, 6, 8]) -def test_alibi_bias(n_heads: int): +def test_alibi_bias(attn_impl: str, n_heads: int): # Test that sliding window attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + if attn_impl == 'flex' and n_heads != 8: + pytest.skip( + 'FlexAttention passes the test individually for n_heads=1, 6, and 8, but not when all three are configured.', + ) # TODO: Investigate why this is the case. dtype = torch.bfloat16 device = 'cuda' d = 128 - seqlen_1 = 8 + seqlen_1 = 6 if attn_impl != 'flex' else 128 # TODO: FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). 
More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 query_1 = torch.randn(bsz, seqlen_1, @@ -248,7 +329,29 @@ def test_alibi_bias(n_heads: int): device=torch.device(device), return_1d=True, ) - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -262,16 +365,8 @@ def test_alibi_bias(n_heads: int): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, alibi_slopes=alibi_slopes_1, + **extra_attn_kwargs, ) output_1.sum().backward() @@ -341,16 +436,34 @@ def gen_bias(): reason= 'attn_logit_softcapping only supported by Flash Attention after v2.6.2.', ) +@pytest.mark.parametrize('attn_impl', ['flash', 'flex']) @pytest.mark.parametrize( 'attn_logit_softcapping', [None, 0.1, 1.0, 10.0, 100.0], ) -def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): +def test_attn_logit_softcapping( + attn_impl: str, + attn_logit_softcapping: Optional[float], +): # Test that attn_logit_softcapping in attention works as expected. + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) + if attn_impl == 'flex' and attn_logit_softcapping is not None: + if int(attn_logit_softcapping) != attn_logit_softcapping: + pytest.skip( + 'FlexAttention does not support attn_logit_softcapping with float softcap temperature.', + ) + else: + attn_logit_softcapping = int(attn_logit_softcapping) + dtype = torch.bfloat16 device = 'cuda' d = 128 - seqlen_1 = 8 + seqlen_1 = 8 if attn_impl != 'flex' else 128 # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). 
More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work bsz = 2 n_heads = 4 @@ -363,7 +476,29 @@ def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, device=device) value_1.requires_grad = True - output_1, _, _ = flash_attn_fn( + extra_attn_kwargs = {} + if attn_impl == 'flash': + extra_attn_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + elif attn_impl == 'flex': + extra_attn_kwargs = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, value=value_1, @@ -377,16 +512,8 @@ def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, attn_logit_softcapping=attn_logit_softcapping, + **extra_attn_kwargs, ) output_1.sum().backward() diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index c1435b4702..0d42c53985 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -7,6 +7,8 @@ import torch from omegaconf import DictConfig from omegaconf import OmegaConf as om +from packaging import version +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from llmfoundry.models.layers import attention from llmfoundry.models.layers.attention import ( @@ -14,6 +16,7 @@ gen_slopes, is_flash_v2_installed, ) +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.layers.layer_builders import build_attention_layer from llmfoundry.models.mpt.modeling_mpt import ( apply_sequence_id, @@ -22,6 +25,12 @@ gen_rotary_embedding, ) +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if FLEX_ATTN_COMPILE: + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + def allclose_helper( t0: torch.Tensor, @@ -74,9 +83,13 @@ def gen_bias( @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl_0, attn_impl_1', [ - ('flash', 'torch'), -]) +@pytest.mark.parametrize( + 'attn_impl_0, attn_impl_1', + [ + ('flash', 'torch'), + ('flex', 'torch'), + ], +) @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize( 'qk_ln, qk_gn', @@ -140,6 +153,12 @@ def test_attn_impl( Includes testing with and without attn_clip_qkv, attn_qk_ln, attn_qk_gn, alibi, and rope. 
""" + if (attn_impl_0 == 'flex' or attn_impl_1 == 'flex') and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] if alibi and not ( @@ -161,7 +180,7 @@ def test_attn_impl( pytest.skip('attn_uses_sequence_id requires alibi or rope.') cfg = om.create({ - 'attn_impl': 'flash', + 'attn_impl': attn_impl_0, 'd_model': 64, 'n_heads': 4, 'attn_pdrop': 0, @@ -264,7 +283,7 @@ def test_attn_impl( sequence_id, ) alibi_slopes_0 = None - if alibi and attn_impl_0 == 'flash': + if alibi and (attn_impl_0 == 'flash' or attn_impl_0 == 'flex'): alibi_slopes_0 = gen_slopes( n_heads=cfg.n_heads, alibi_bias_max=8, @@ -298,7 +317,17 @@ def test_attn_impl( 'seq_len': s, } - + extra_kwargs = {} + if attn_impl_0 == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } + if sequence_id is not None: + extra_kwargs['flex_attn_kwargs']['sequence_id_info'][ + 'sequence_id'] = sequence_id y0, _, _ = attn0( x0, past_key_value=None, @@ -308,6 +337,7 @@ def test_attn_impl( is_causal=True, flash_attn_padding_info=flash_attn_padding_info_0, alibi_slopes=alibi_slopes_0, + **extra_kwargs, ) attn_bias_1 = gen_bias( attn_impl_1, @@ -319,13 +349,26 @@ def test_attn_impl( sequence_id, ) alibi_slopes_1 = None - if alibi and attn_impl_1 == 'flash': + if alibi and (attn_impl_1 == 'flash' or attn_impl_1 == 'flex'): alibi_slopes_1 = gen_slopes( n_heads=cfg.n_heads, alibi_bias_max=8, device=torch.device(device), return_1d=True, ) + + extra_kwargs = {} + if attn_impl_1 == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + } + if sequence_id is not None: + extra_kwargs['flex_attn_kwargs']['sequence_id_info'] = { + 'sequence_id': sequence_id, + } + y1, _, _ = attn1( x1, past_key_value=None, @@ -335,6 +378,7 @@ def test_attn_impl( is_causal=True, flash_attn_padding_info=flash_attn_padding_info_1, alibi_slopes=alibi_slopes_1, + **extra_kwargs, ) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -372,9 +416,15 @@ def test_attn_impl( @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) def test_vs_mha(attn_impl: str, device: str = 'cuda'): """Compare diff attn_impl to torch.nn.MultiheadAttention.""" + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) from llmfoundry.models.layers import attention cfg = om.create({ @@ -429,6 +479,14 @@ def gen_tca_mask(): None, attention_mask, ) + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } y0, _, _ = mmhsa( x0, past_key_value=None, @@ -436,6 +494,7 @@ def gen_tca_mask(): attention_mask=attention_mask, is_causal=True, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) y1, _ = tmhsa( x1, @@ 
-481,7 +540,7 @@ def gen_tca_mask(): @pytest.mark.gpu -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) @pytest.mark.parametrize('n_heads', [16, 8]) @pytest.mark.parametrize('kv_n_heads', [4, 2, 1]) def test_grouped_attention_heads( @@ -491,6 +550,12 @@ def test_grouped_attention_heads( device: str = 'cuda', ): """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads.""" + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) from llmfoundry.models.layers import attention cfg = om.create({ @@ -522,6 +587,14 @@ def test_grouped_attention_heads( None, attention_mask, ) + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': {}, + } y0, _, _ = mmhsa( x0, past_key_value=None, @@ -529,6 +602,7 @@ def test_grouped_attention_heads( attention_mask=attention_mask, is_causal=True, flash_attn_padding_info=flash_attn_padding_info, + **extra_kwargs, ) y0 *= attention_mask.unsqueeze(-1) @@ -589,13 +663,19 @@ def test_grouped_query_invalid_heads(): }, }], ) -@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch', 'flex']) def test_reuse_prev_layer_kv_cache( pos_emb_config: dict, attn_impl: str, device: str = 'cuda', ): """Checks reusing previous layer's kv cache.""" + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] @@ -637,7 +717,6 @@ def test_reuse_prev_layer_kv_cache( attn1.load_state_dict(attn0_sd) attention_mask = torch.ones(n, s).to(device).bool() - attention_mask_in_length = gen_attention_mask_in_length( sequence_id=sequence_id, S=s, @@ -705,7 +784,16 @@ def test_reuse_prev_layer_kv_cache( 'seq_len': s, } - + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'sequence_id': sequence_id, + }, + } y0, _, prev_layer_key_value = attn0( x0, past_key_value=(), @@ -715,6 +803,7 @@ def test_reuse_prev_layer_kv_cache( is_causal=True, flash_attn_padding_info=flash_attn_padding_info, alibi_slopes=alibi_slopes_0, + **extra_kwargs, ) attn_bias_1 = gen_bias( attn_impl, @@ -737,6 +826,16 @@ def test_reuse_prev_layer_kv_cache( prev_layer_key_value = [ t.clone().detach() for t in prev_layer_key_value ] + extra_kwargs = {} + if attn_impl == 'flex': + extra_kwargs['flex_attn_kwargs'] = { + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, + 'flex_attn_compile': FLEX_ATTN_COMPILE, + 'sequence_id_info': { + 'sequence_id': sequence_id, + }, + } y1, _, _ = attn1( x1, past_key_value=None, @@ -747,6 +846,7 @@ def test_reuse_prev_layer_kv_cache( flash_attn_padding_info=flash_attn_padding_info, alibi_slopes=alibi_slopes_1, prev_layer_key_value=prev_layer_key_value, + **extra_kwargs, ) y0 *= attention_mask.unsqueeze(-1) y1 *= 
attention_mask.unsqueeze(-1) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 2697956850..366667bb21 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -27,6 +27,7 @@ ) from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from packaging import version from transformers import ( AutoModelForCausalLM, PreTrainedModel, @@ -47,6 +48,7 @@ is_flash_v2_installed, ) from llmfoundry.models.layers.blocks import MPTBlock +from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel from llmfoundry.models.mpt.modeling_mpt import ( CROSS_ENTROPY_IGNORE_INDEX, @@ -72,6 +74,7 @@ def _get_objs( conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', model_config_overrides: Optional[dict] = None, attn_impl: str = 'torch', + flex_attn_compile: bool = FLEX_ATTN_COMPILE, ): warnings.filterwarnings( action='ignore', @@ -101,6 +104,7 @@ def _get_objs( test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32' test_cfg.model.attn_config = { 'attn_impl': attn_impl, + 'flex_attn_compile': flex_attn_compile, } test_cfg.model.init_device = device test_cfg.device = device @@ -468,7 +472,9 @@ def test_full_forward_and_backward_t5_small( 'attn_impl,precision', [('torch', torch.float16), ('torch', torch.bfloat16), pytest.param('flash', torch.float16, marks=pytest.mark.gpu), - pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)], + pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu), + pytest.param('flex', torch.float16, marks=pytest.mark.gpu), + pytest.param('flex', torch.bfloat16, marks=pytest.mark.gpu)], ) @pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu']) @pytest.mark.parametrize( @@ -500,12 +506,19 @@ def test_determinism( ffn_act_fn: dict, tiny_neox_tokenizer: PreTrainedTokenizerBase, ): + if attn_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: test_cfg = om.load(f) test_cfg.model.attn_config = { 'attn_impl': attn_impl, + 'flex_attn_compile': FLEX_ATTN_COMPILE, } if hasattr(test_cfg.model, 'ffn_config'): test_cfg.model.ffn_config['ffn_type'] = ffn_type @@ -959,7 +972,7 @@ def test_mb_mpt_creation(): @pytest.mark.gpu -@pytest.mark.parametrize('attention_impl', ['flash', 'torch']) +@pytest.mark.parametrize('attention_impl', ['flash', 'torch', 'flex']) @pytest.mark.parametrize( 'pos_emb_config', [{ @@ -1005,6 +1018,12 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): pytest.skip( 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.', ) + if attention_impl == 'flex' and version.parse( + torch.__version__.split('.dev')[0], + ) < version.parse('2.5.1'): + pytest.skip( + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', + ) composer_device = get_device(None) @@ -1020,6 +1039,7 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): attn_config={ 'attn_impl': attention_impl, 'attn_uses_sequence_id': True, + 'flex_attn_compile': FLEX_ATTN_COMPILE, **pos_emb_config, }, init_config={ @@ -1088,6 +1108,7 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): [ 'torch', pytest.param('flash', marks=pytest.mark.gpu), + pytest.param('flex', marks=pytest.mark.gpu), 
         pytest.param('torch', marks=pytest.mark.gpu),
     ],
 )
@@ -1127,6 +1148,12 @@ def test_forward_with_padding(
     tie_word_embeddings: bool,
 ):
     # Test that different placement of padding does not affect the output.
+    if attention_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     alibi = pos_emb_config['alibi']
     if alibi and not check_alibi_support(attention_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
@@ -1151,6 +1178,7 @@ def test_forward_with_padding(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attention_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         init_config={
@@ -1344,6 +1372,7 @@ def test_forward_with_padding(
     [
         ('torch', 'fp32'),
         pytest.param('flash', 'amp_bf16', marks=pytest.mark.gpu),
+        pytest.param('flex', 'amp_bf16', marks=pytest.mark.gpu),
         pytest.param('torch', 'amp_bf16', marks=pytest.mark.gpu),
         pytest.param('torch', 'fp32', marks=pytest.mark.gpu),
     ],
@@ -1386,6 +1415,12 @@ def test_generate(
 ):
     # Test that generate works, and produces the same output with or without
     # padding in the input.
+    if attention_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     if pos_emb_config['alibi'] and not check_alibi_support(attention_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

@@ -1398,7 +1433,9 @@ def test_generate(
         pytest.skip(f'This test configuration has precision / sampling issues.')

     composer_device = get_device(None)
-
+    reproducibility.seed_all(
+        4,
+    )  # Flex attention fails for the default seed but works for all other seeds tested; the output logits are likely close enough that a slight numerical imprecision changes the generated output.
     hf_config = MPTConfig(
         init_device='cpu',
         d_model=128,
@@ -1410,6 +1447,8 @@ def test_generate(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attention_impl,
+            'flex_attn_compile':
+                False,  # TODO: Needs these issues to be fixed: https://github.com/pytorch/pytorch/issues/139064, https://github.com/pytorch/pytorch/issues/139544. Causes errors even with dynamic=True and/or fullgraph=True.
             **pos_emb_config,
         },
         tie_word_embeddings=tie_word_embeddings,
@@ -1640,6 +1679,7 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
     [
         'torch',
         pytest.param('flash', marks=pytest.mark.gpu),
+        pytest.param('flex', marks=pytest.mark.gpu),
     ],
 )
 @pytest.mark.parametrize(
@@ -1680,6 +1720,12 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.',
         )
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )

     composer_device = get_device(None)

@@ -1694,6 +1740,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -1802,6 +1849,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
     [
         'torch',
         pytest.param('flash', marks=pytest.mark.gpu),
+        pytest.param('flex', marks=pytest.mark.gpu),
     ],
 )
 @pytest.mark.parametrize(
@@ -1841,6 +1889,12 @@ def test_forward_with_cache(
 ):
     # Test that model forward with and without the key-value cache produces the
     # same output.
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

@@ -1863,6 +1917,7 @@ def test_forward_with_cache(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -1973,6 +2028,7 @@ def test_forward_with_cache(
     [
         'torch',
         pytest.param('flash', marks=pytest.mark.gpu),
+        pytest.param('flex', marks=pytest.mark.gpu),
     ],
 )
 @pytest.mark.parametrize(
@@ -2010,6 +2066,12 @@ def test_generate_with_past_kv(
     pos_emb_config: dict,
     tie_word_embeddings: bool,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
     if pos_emb_config['rope'] and pos_emb_config[
@@ -2031,6 +2093,7 @@ def test_generate_with_past_kv(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2095,6 +2158,7 @@ def test_generate_with_past_kv(
     [
         'torch',
         pytest.param('flash', marks=pytest.mark.gpu),
+        pytest.param('flex', marks=pytest.mark.gpu),
     ],
 )
 @pytest.mark.parametrize(
@@ -2142,6 +2206,12 @@ def test_generation_kwargs_dont_crash(
     pos_emb_config: dict,
     tie_word_embeddings: bool,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

@@ -2166,6 +2236,7 @@ def test_generation_kwargs_dont_crash(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2320,6 +2391,7 @@ def test_alibi_vs_hf():
     [
         'torch',
         pytest.param('flash', marks=pytest.mark.gpu),
+        pytest.param('flex', marks=pytest.mark.gpu),
         pytest.param('torch', marks=pytest.mark.gpu),
     ],
 )
@@ -2356,9 +2428,15 @@ def test_forward_with_output_attentions_and_output_hidden_states(
     attn_impl: str,
     pos_emb_config: dict,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     if pos_emb_config['alibi'] and not check_alibi_support(attn_impl):
         pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
-    if attn_impl == 'flash':
+    if attn_impl == 'flash' or attn_impl == 'flex':
         pytest.skip(f'output_attentions only implemented with torch attention.')
     if pos_emb_config['rope'] and pos_emb_config[
         'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -2381,6 +2459,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
         resid_pdrop=0.2,
         attn_config={
             'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             **pos_emb_config,
         },
         use_cache=True,
@@ -2515,10 +2594,18 @@ def test_hf_init(


 @pytest.mark.gpu
+@pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex'])
 def test_head_dim_8_flash_mqa_attn(
+    attn_impl: str,
     tiny_neox_tokenizer: PreTrainedTokenizerBase,
     batch_size: int = 2,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')

     test_cfg.device = torch.cuda.current_device()
@@ -2534,7 +2621,8 @@ def test_head_dim_8_flash_mqa_attn(
         emb_pdrop=0.1,
         resid_pdrop=0.2,
         attn_config={
-            'attn_impl': 'flash',
+            'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             'attn_type': 'multiquery_attention',
         },
     )
@@ -2562,7 +2650,14 @@ def test_head_dim_8_flash_mqa_attn(
     assert not torch.isnan(output.logits).any()


-def test_construct_blocks():
+@pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex'])
+def test_construct_blocks(attn_impl: str):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     n_layers = 13

     config = MPTConfig(
@@ -2572,7 +2667,8 @@ def test_construct_blocks():
         expansion_ratio=2,
         max_seq_len=64,
         attn_config={
-            'attn_impl': 'flash',
+            'attn_impl': attn_impl,
+            'flex_attn_compile': FLEX_ATTN_COMPILE,
             'attn_type': 'grouped_query_attention',
             'kv_n_heads': 4,
         },
@@ -2665,10 +2761,18 @@ def test_construct_blocks():


 @pytest.mark.gpu
+@pytest.mark.parametrize('attn_impl', ['torch', 'flash', 'flex'])
 def test_reuse_prev_layer_kv_cache(
+    attn_impl: str,
     request: pytest.FixtureRequest,
     batch_size: int = 2,
 ):
+    if attn_impl == 'flex' and version.parse(
+        torch.__version__.split('.dev')[0],
+    ) < version.parse('2.5.1'):
+        pytest.skip(
+            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
+        )
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     model_config_overrides = {
         'block_overrides': {
@@ -2694,7 +2798,8 @@ def test_reuse_prev_layer_kv_cache(
         request=request,
         conf_path=conf_path,
         model_config_overrides=model_config_overrides,
-        attn_impl='flash',
+        attn_impl=attn_impl,
+        flex_attn_compile=FLEX_ATTN_COMPILE,
     )

     batch = gen_random_batch(batch_size, test_cfg)
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 90ef3bfaac..8ddce5125d 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -40,6 +40,7 @@ def test_expected_registries_exist():
         'ffns',
         'ffns_with_norm',
         'ffns_with_megablocks',
+        'flex_attention_mods',
         'attention_classes',
         'attention_implementations',
         'fcs',
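
Note: the torch-version gate and the flex_attn_kwargs dictionary added above are repeated nearly verbatim across the modified tests. Below is a minimal sketch of how they could be factored into shared test helpers; the names skip_if_flex_unsupported and make_flex_attn_kwargs are hypothetical and not part of this diff, and the sketch only assumes the FLEX_ATTN_COMPILE constant that the diff already imports from llmfoundry.models.layers.flex_attn_utils.

# Sketch only: hypothetical helpers for the repeated flex-attention test setup.
from typing import Any, Optional

import pytest
import torch
from packaging import version

from llmfoundry.models.layers.flex_attn_utils import FLEX_ATTN_COMPILE


def skip_if_flex_unsupported(attn_impl: str) -> None:
    # Skip the calling test when 'flex' is requested but torch is too old.
    if attn_impl == 'flex' and version.parse(
        torch.__version__.split('.dev')[0],
    ) < version.parse('2.5.1'):
        pytest.skip(
            f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
        )


def make_flex_attn_kwargs(
    compiled_flex_attention: Any,
    compiled_create_block_mask: Any,
    sequence_id: Optional[torch.Tensor] = None,
) -> dict[str, Any]:
    # Build the extra kwargs dict the tests splat into the attention call.
    sequence_id_info: dict[str, Any] = {}
    if sequence_id is not None:
        sequence_id_info['sequence_id'] = sequence_id
    return {
        'flex_attn_kwargs': {
            'compiled_flex_attention': compiled_flex_attention,
            'compiled_create_block_mask': compiled_create_block_mask,
            'flex_attn_compile': FLEX_ATTN_COMPILE,
            'sequence_id_info': sequence_id_info,
        },
    }

A test would then call skip_if_flex_unsupported(attn_impl) at the top and, when attn_impl == 'flex', pass **make_flex_attn_kwargs(compiled_flex_attention, compiled_create_block_mask, sequence_id) into the attention call instead of rebuilding the dictionary inline.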