
FA3 #3623

Merged
merged 5 commits on Jul 15, 2025

Changes from 4 commits
221 changes: 199 additions & 22 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -6,9 +6,23 @@
import torch

from lmdeploy.pytorch.distributed import get_tp_world_rank
from lmdeploy.utils import get_logger

from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata

logger = get_logger('lmdeploy')

use_fa3 = False
try:
    # Currently, flash-attention only supports FA3 on sm90a with CUDA >= 12.3
if (torch.cuda.get_device_capability()[0] == 9) and (torch.version.cuda >= '12.3'):
import flash_attn_interface # noqa: F401
Collaborator
The packaging of FlashAttention-3 is not well done; one simple alternative is to import sgl_kernel, which can be installed directly from PyPI and used out of the box.

from sgl_kernel.flash_attn import flash_attn_with_kvcache
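
A minimal sketch of the reviewer's suggestion, illustrative only and not part of this diff: it assumes sgl_kernel exposes flash_attn_with_kvcache at the import path named above, and it swaps the flash_attn_interface probe for an sgl_kernel one.

use_fa3 = False
try:
    # Probe the sgl_kernel wheel instead of a local FlashAttention-3 build.
    if torch.cuda.get_device_capability()[0] == 9:
        from sgl_kernel.flash_attn import flash_attn_with_kvcache  # noqa: F401
        use_fa3 = True
except Exception:
    logger.warning('sgl_kernel is not installed; falling back to the Triton attention kernels.')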

assert torch.ops.flash_attn_3 is not None
use_fa3 = True
except Exception:
logger.warning('For higher performance, please install FlashAttention-3 '
'https://github.com/Dao-AILab/flash-attention')


@dataclass
class TritonAttentionMetadata(AttentionMetadata):
@@ -25,6 +39,8 @@ class TritonAttentionMetadata(AttentionMetadata):
# flash mla
tile_scheduler_metadata: torch.Tensor = None
num_splits: torch.Tensor = None
cu_seqlens_q: torch.Tensor = None
cu_seqlens_k: torch.Tensor = None


def _cdiv(a, b):
@@ -89,7 +105,6 @@ def forward(
inplace: bool = True,
) -> torch.Tensor:
"""forward."""

block_offsets = attn_metadata.block_offsets
q_start_loc = attn_metadata.q_start_loc
fill_q_start_loc = q_start_loc
@@ -129,7 +144,6 @@ def forward(
q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_size, )
attn_output = query.new_empty(o_shape)

is_decoding = attn_metadata.is_decoding
if not self.alibi:
if is_decoding:
@@ -286,7 +300,6 @@ def forward(

q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_size, )
attn_output = query.new_empty(o_shape)

is_decoding = attn_metadata.is_decoding
if is_decoding:
@@ -302,7 +315,6 @@
tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
num_splits=attn_metadata.num_splits,
causal=True)

else:
BLOCK_BS = k_cache.size(1)
# pad one more block to avoid invalid kv visit
@@ -313,26 +325,179 @@ def forward(
kv_seqlens,
block_offsets,
start_loc=kv_start_loc,
out_size=out_size,
out_size=kv_flatten_size if use_fa3 else out_size,
out_dtype=query.dtype,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
flatten_kv_layout='bshd' if use_fa3 else 'bhsd',
)
self.flash_attention_fwd(
if use_fa3:
q_rope = query[:, :, self.v_head_size:]
q_nope = query[:, :, :self.v_head_size]
k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]
c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]
from flash_attn_interface import flash_attn_varlen_func
attn_output, _ = flash_attn_varlen_func(
q=q_rope,
k=k_rope,
v=c_kv,
qv=q_nope,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
cu_seqlens_k=attn_metadata.cu_seqlens_k,
max_seqlen_q=max_q_seqlen,
max_seqlen_k=kv_flatten_size,
softmax_scale=self.scale,
causal=self.causal,
window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
)
else:
attn_output = query.new_empty(o_shape)
self.flash_attention_fwd(
query,
flatten_k,
flatten_v,
attn_output,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
max_seqlen=max_q_seqlen,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
causal=self.causal,
)
return attn_output
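
The slicing in the FA3 prefill branch above packs each MLA query head as [nope | rope]: the first v_head_size dims are the compressed (nope) part handed to FA3 through its qv argument, and the trailing dims are the RoPE part passed as q. A shape-only sketch with assumed DeepSeek-style sizes (kv_lora_rank 512, rope dim 64); the values are illustrative and not taken from this PR:

import torch

v_head_size, rope_dim = 512, 64                        # assumed MLA head dims
query = torch.randn(16, 128, v_head_size + rope_dim)   # [tokens, heads, 576]
q_nope = query[:, :, :v_head_size]                     # 512-dim part -> qv
q_rope = query[:, :, v_head_size:]                     # 64-dim RoPE part -> q
assert q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64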


class FA3Impl(TritonAttentionImpl):
"""Triton attention implementation."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float = None,
num_kv_heads: int = None,
v_head_size: int = None,
alibi: bool = False,
sliding_window: int = None,
logit_softcapping: float = None,
causal: bool = True,
**kwargs,
):
assert alibi is False, 'alibi not supported for FA3'
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_size=v_head_size,
alibi=alibi,
sliding_window=sliding_window,
logit_softcapping=logit_softcapping,
causal=causal,
**kwargs,
)
from flash_attn_interface import flash_attn_varlen_func
self.flash_attn_varlen_func_v3 = flash_attn_varlen_func

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
inplace: bool = True,
) -> torch.Tensor:
"""forward."""
block_offsets = attn_metadata.block_offsets
q_start_loc = attn_metadata.q_start_loc
fill_q_start_loc = q_start_loc
q_seqlens = attn_metadata.q_seqlens
fill_seqlens = q_seqlens
kv_start_loc = attn_metadata.kv_start_loc
kv_seqlens = attn_metadata.kv_seqlens
kv_flatten_size = attn_metadata.kv_flatten_size
quant_policy = attn_metadata.quant_policy
if attn_metadata.is_decoding:
max_q_seqlen = 1
else:
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
fill_max_q_seqlen = max_q_seqlen
if attn_metadata.fill_seqlens is not None:
fill_seqlens = attn_metadata.fill_seqlens
fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
is_decoding = attn_metadata.is_decoding
# fill kv cache
if key is not None and value is not None:
self.fill_kv_cache(
key,
value,
k_cache,
v_cache,
fill_q_start_loc,
fill_seqlens,
kv_seq_length=kv_seqlens,
max_q_seq_length=fill_max_q_seqlen,
block_offsets=block_offsets,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
)

q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_size, )
attn_output = query.new_empty(o_shape)

        if is_decoding:
            self.paged_attention_fwd(
                query,
                k_cache,
                v_cache,
                attn_output,
                block_offsets,
                kv_seqlens=kv_seqlens,
                k_scales_zeros=k_scales_zeros,
                v_scales_zeros=v_scales_zeros,
                quant_policy=quant_policy,
                window_size=self.sliding_window,
                sm_scale=self.scale,
                logit_softcapping=self.logit_softcapping,
            )
else:
flatten_k, flatten_v = self.flatten_kv_cache(
k_cache,
v_cache,
kv_seqlens,
block_offsets,
start_loc=kv_start_loc,
out_size=kv_flatten_size,
out_dtype=query.dtype,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
flatten_kv_layout='bshd',
)
attn_output, _ = self.flash_attn_varlen_func_v3(
q=query,
k=flatten_k,
v=flatten_v,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
cu_seqlens_k=attn_metadata.cu_seqlens_k,
max_seqlen_q=max_q_seqlen,
max_seqlen_k=kv_flatten_size,
softmax_scale=self.scale,
causal=self.causal,
window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
)
return attn_output
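
For orientation, a sketch of the tensor shapes FA3Impl's varlen call consumes in the prefill branch, under assumed sizes (two sequences of lengths 3 and 5, 32 query heads, 8 KV heads, head_dim 128); the numbers are illustrative, not from the PR:

import torch

total_tokens = 8                                          # sum of q_seqlens
query = torch.randn(total_tokens, 32, 128)                # packed tokens, no batch dim
flatten_k = torch.randn(total_tokens, 8, 128)             # 'bshd' flatten: tokens first
flatten_v = torch.randn(total_tokens, 8, 128)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)   # batch + 1 boundaries
# flash_attn_varlen_func_v3(q=query, k=flatten_k, v=flatten_v,
#                           cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#                           max_seqlen_q=5, max_seqlen_k=total_tokens, causal=True)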

@@ -366,13 +531,25 @@ def build(
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
return TritonAttentionImpl(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_size=v_head_size,
alibi=alibi,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
elif use_fa3 and not alibi:
return FA3Impl(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_size=v_head_size,
alibi=alibi,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
else:
return TritonAttentionImpl(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_size=v_head_size,
alibi=alibi,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -131,6 +131,8 @@ def update_step_context(cls, step_context):
kv_seqlens = step_context.kv_seqlens
kv_start_loc = None
kv_flatten_size = None
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
if not step_context.is_decoding:
kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
kv_flatten_size = kv_seqlens.sum().item()
@@ -143,6 +145,8 @@ def update_step_context(cls, step_context):
kv_seqlens=kv_seqlens,
kv_flatten_size=kv_flatten_size,
quant_policy=step_context.kv_quant_policy,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if getattr(step_context.model_config, 'use_flash_mla', False) is True:
if step_context.is_decoding is True:
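
The two padded cumulative sums added above are the exclusive prefix sums (with a leading zero) that flash-attention varlen kernels expect as cu_seqlens. A quick worked example with made-up lengths:

import torch

q_seqlens = torch.tensor([3, 5, 2])
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)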
51 changes: 34 additions & 17 deletions lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
@@ -202,7 +202,8 @@ def flatten_kv_cache(k_caches: Tensor,
k_scales_zeros: Tensor = None,
v_scales_zeros: Tensor = None,
quant_policy: Literal[0, 4, 8] = 0,
kv_layout: str = 'bshd'):
kv_layout: str = 'bshd',
flatten_kv_layout: str = 'bhsd'):
"""Recovery paged kv cache to normal kv cache."""
if kv_layout == 'bshd':
b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
@@ -230,17 +231,34 @@
BLOCK_DK = triton.next_power_of_2(k_head_dim)
BLOCK_DV = triton.next_power_of_2(v_head_dim)
BLOCK_BS = k_caches.size(s_dim)

k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)

grid = (num_blocks, batch_size, num_heads)
if quant_policy == 0:
shared_kv = k_caches.data_ptr() == v_caches.data_ptr() and v_head_dim < k_head_dim
if shared_kv:
shared_kv = k_caches.data_ptr() == v_caches.data_ptr() and v_head_dim < k_head_dim
if flatten_kv_layout == 'bhsd':
k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)
if quant_policy == 0 and shared_kv:
v_states = k_states[..., :v_head_dim]
v_head_dim = 0
else:
v_states = v_caches.new_empty(num_heads, out_size, v_head_dim, dtype=out_dtype)
stride_koh = k_states.stride(0)
stride_kos = k_states.stride(1)
stride_voh = v_states.stride(0)
stride_vos = v_states.stride(1)
elif flatten_kv_layout == 'bshd':
k_states = k_caches.new_empty(out_size, num_heads, k_head_dim, dtype=out_dtype)
if quant_policy == 0 and shared_kv:
v_states = k_states[..., :v_head_dim]
v_head_dim = 0
else:
v_states = v_caches.new_empty(out_size, num_heads, v_head_dim, dtype=out_dtype)
stride_koh = k_states.stride(1)
stride_kos = k_states.stride(0)
stride_voh = v_states.stride(1)
stride_vos = v_states.stride(0)
else:
raise RuntimeError('Unsupported layout.')

grid = (num_blocks, batch_size, num_heads)
if quant_policy == 0:
_flatten_kv_cache[grid](
k_caches,
v_caches,
@@ -257,11 +275,11 @@
stride_vcs=v_caches.stride(s_dim),
stride_vch=v_caches.stride(h_dim),
stride_vcd=v_caches.stride(d_dim),
stride_koh=k_states.stride(0),
stride_kos=k_states.stride(1),
stride_koh=stride_koh,
stride_kos=stride_kos,
stride_kod=k_states.stride(2),
stride_voh=v_states.stride(0),
stride_vos=v_states.stride(1),
stride_voh=stride_voh,
stride_vos=stride_vos,
stride_vod=v_states.stride(2),
stride_boff=block_offsets.stride(0),
OUT_SIZE=out_size,
@@ -272,7 +290,6 @@
BLOCK_DV=BLOCK_DV,
)
else:
v_states = v_caches.new_empty(num_heads, out_size, v_head_dim, dtype=out_dtype)
_flatten_kv_cache_quant[grid](
k_caches,
v_caches,
@@ -299,11 +316,11 @@
stride_vszs=v_scales_zeros.stride(s_dim),
stride_vszh=v_scales_zeros.stride(h_dim),
stride_vszd=v_scales_zeros.stride(d_dim),
stride_koh=k_states.stride(0),
stride_kos=k_states.stride(1),
stride_koh=stride_koh,
stride_kos=stride_kos,
stride_kod=k_states.stride(2),
stride_voh=v_states.stride(0),
stride_vos=v_states.stride(1),
stride_voh=stride_voh,
stride_vos=stride_vos,
stride_vod=v_states.stride(2),
stride_boff=block_offsets.stride(0),
quant_policy=quant_policy,
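
A small sketch of what the new flatten_kv_layout switch changes: 'bhsd' keeps heads as the leading dimension of the flattened cache (the layout the Triton prefill kernel reads), while 'bshd' puts the flattened tokens first, which is what the FA3 varlen path expects. Illustrative sizes only:

import torch

num_heads, out_size, head_dim = 8, 1024, 128
bhsd = torch.empty(num_heads, out_size, head_dim)   # flatten_kv_layout='bhsd'
bshd = torch.empty(out_size, num_heads, head_dim)   # flatten_kv_layout='bshd'
print(bhsd.stride())  # (131072, 128, 1) -> heads-major
print(bshd.stride())  # (1024, 128, 1)   -> tokens-major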