diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py
index 8c52aef3..6a45c8f6 100644
--- a/vllm_gaudi/extension/unified.py
+++ b/vllm_gaudi/extension/unified.py
@@ -187,22 +187,32 @@ def partial_attn_unique(query: torch.tensor, blocks: torch.tensor, block_mapping
     kv_heads = cache_utils.kv_heads
     query = query.index_select(0, block_mapping).unflatten(1, (kv_heads, -1)).unsqueeze(-2)
+    head_dim = query.size(2)
     key, value = cache_utils.fetch_unique(blocks)
     block_mapping_2d = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size).to(query.dtype)
     attn = torch.matmul(query, key.transpose(-1, -2))
-    attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
-    block_max = torch.maximum(attn.amax(-1), fmin)
-    attn = torch.exp(attn - block_max.unsqueeze(-1))
-    block_sum = attn.sum(-1)
+    block_bias = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+    attn, block_max, block_sum = torch.ops.hpu.block_softmax(attn, block_bias, block_mapping, fp8_exp=False)
+    attn = torch.matmul(attn, value)
+    # Reshape outputs
+    block_max = block_max[:, :kv_heads * head_dim].view(-1, kv_heads, head_dim,
+                                                        1)  # [num_blocks, kv_heads, head_dim, 1]
+    block_sum = block_sum[:, :kv_heads * head_dim].view(-1, kv_heads, head_dim,
+                                                        1)  # [num_blocks, kv_heads, head_dim, 1]
+    group_max = reduce_max(block_max, batch_size, block_mapping)
+    block_adjustment = torch.exp(block_max - group_max.index_select(0, block_mapping))
+    block_sum = block_sum * block_adjustment
     group_sum = block2batch(block_sum, block_mapping_2d)
+    attn = attn * block_adjustment.unsqueeze(-1)
     attn = block2batch(attn, block_mapping_2d)
+    return (attn.flatten(1, 3), group_max.flatten(1, 3), group_sum.flatten(1, 3))
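
For context, the patch replaces the explicit bias-add / max / exp / sum sequence with a single fused `torch.ops.hpu.block_softmax` call and then rescales each block's partial results against the per-sequence maximum before merging. The sketch below is a plain-PyTorch approximation of that math under assumed, unpacked shapes (`block_out`: `[num_blocks, kv_heads, q_len, head_dim]`, `block_max`/`block_sum`: `[num_blocks, kv_heads, q_len]`); the helpers `block_softmax_reference` and `combine_blocks` are illustrative names, not the kernel API, and it does not reproduce the kernel's packed output layout that the `[:, :kv_heads * head_dim]` slicing undoes.

```python
import torch


def block_softmax_reference(attn, block_bias, fmin):
    """Plain-PyTorch stand-in for the fused block softmax: bias add,
    per-block max subtraction and exponentiation, per-block sum."""
    attn = attn + block_bias
    block_max = torch.maximum(attn.amax(-1), fmin)    # running max per block
    attn = torch.exp(attn - block_max.unsqueeze(-1))  # unnormalized probabilities
    block_sum = attn.sum(-1)                          # partial denominator per block
    return attn, block_max, block_sum


def combine_blocks(block_out, block_max, block_sum, block_mapping, batch_size):
    """Merge per-block partial attention into per-sequence results by rescaling
    each block with exp(block_max - group_max), mirroring block_adjustment."""
    # Per-sequence maximum over all blocks mapped to that sequence.
    group_max = torch.full((batch_size, *block_max.shape[1:]),
                           torch.finfo(block_max.dtype).min,
                           dtype=block_max.dtype, device=block_max.device)
    group_max.index_reduce_(0, block_mapping, block_max, "amax")

    # Rescale each block's partial results to the shared per-sequence maximum.
    adjustment = torch.exp(block_max - group_max.index_select(0, block_mapping))
    block_sum = block_sum * adjustment
    block_out = block_out * adjustment.unsqueeze(-1)

    # Sum the rescaled blocks into their target sequences (block2batch equivalent).
    group_sum = torch.zeros((batch_size, *block_sum.shape[1:]),
                            dtype=block_sum.dtype, device=block_sum.device)
    group_sum.index_add_(0, block_mapping, block_sum)
    group_out = torch.zeros((batch_size, *block_out.shape[1:]),
                            dtype=block_out.dtype, device=block_out.device)
    group_out.index_add_(0, block_mapping, block_out)
    return group_out, group_max, group_sum
```

The final normalization (dividing the accumulated output by `group_sum`) is left to the caller, matching the patched function, which returns the unnormalized output together with `group_max` and `group_sum`, presumably so partial results can be merged further downstream.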