⚡️ Speed up function eager_attention_forward by 12%
#136
+26 −6
📄 12% (0.12x) speedup for eager_attention_forward in src/transformers/models/markuplm/modeling_markuplm.py
⏱️ Runtime: 3.19 milliseconds → 2.84 milliseconds (best of 250 runs)
📝 Explanation and details
The optimized code achieves a 12% speedup through several key micro-optimizations that reduce overhead and memory operations:
What optimizations were applied:

- Cached `nn.functional`, `query.dtype`, and `key.shape[-2]` in local variables to avoid repeated attribute lookups
- Replaced `* scaling` with `attn_weights.mul_(scaling)` and `+ causal_mask` with `attn_weights.add_(causal_mask)` to reduce memory allocation
- Applied `.to(query.dtype)` only when the query isn't already float32, avoiding unnecessary type conversion overhead
- Skipped the `nn.functional.dropout` call entirely when `dropout=0.0` or `module.training=False`, eliminating a no-op function call

Why these optimizations work:

- Caching `nn.functional` and tensor properties eliminates repeated dictionary lookups
- The in-place `mul_()` and `add_()` calls modify tensors without creating new ones, reducing memory bandwidth and allocation overhead

Performance characteristics from test results:
The optimizations are particularly effective for transformer attention layers where this function is called frequently in inference scenarios with no dropout.
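Taken together, the description above corresponds to an attention kernel along the lines of the sketch below. This is a minimal illustrative reconstruction, assuming the usual transformers eager-attention signature (module, query, key, value, attention_mask, scaling, dropout) and an attention mask with the same dtype as the scores; the name `eager_attention_forward_sketch` is made up here, and the actual diff in this PR may differ in details.

```python
import torch
from torch import nn

def eager_attention_forward_sketch(
    module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs
):
    # Cache attribute lookups once instead of resolving them on every use.
    F = nn.functional
    query_dtype = query.dtype
    key_len = key.shape[-2]

    # Raw attention scores; mul_ scales in place instead of allocating a new tensor.
    attn_weights = torch.matmul(query, key.transpose(2, 3))
    attn_weights.mul_(scaling)

    if attention_mask is not None:
        # Slice the mask to the key length and add it in place
        # (assumes the mask dtype matches the scores).
        causal_mask = attention_mask[:, :, :, :key_len]
        attn_weights.add_(causal_mask)

    # Softmax in float32 for stability; cast back only if the inputs weren't float32.
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
    if query_dtype != torch.float32:
        attn_weights = attn_weights.to(query_dtype)

    # Skip the no-op dropout call entirely in eval mode or when dropout == 0.0.
    if dropout != 0.0 and module.training:
        attn_weights = F.dropout(attn_weights, p=dropout, training=True)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```

The guard around `F.dropout` only pays off when dropout is zero or the module is in eval mode, which matches the note above that inference-style calls with no dropout benefit the most.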
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
from torch import nn

from transformers.models.markuplm.modeling_markuplm import eager_attention_forward

# unit tests

# Helper to create dummy module
class DummyModule(nn.Module):
    def __init__(self, training=True):
        super().__init__()
        self.training = training

# ===============================
# BASIC TEST CASES
# ===============================
def test_basic_identity_attention():
    # Test with identity query/key/value and no mask, scaling=1, dropout=0
    # Softmax of an identity score matrix should put the largest weight on the diagonal of each row
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 3, 3
    query = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    key = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    value = torch.arange(seq_len).float().reshape(1, 1, seq_len, 1)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0
def test_basic_uniform_attention():
    # All query/key are zeros, so softmax should be uniform
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 2, 1, 4, 5
    query = torch.zeros(batch_size, num_heads, seq_len, head_dim)
    key = torch.zeros(batch_size, num_heads, seq_len, head_dim)
    value = torch.arange(seq_len).float().reshape(1, 1, seq_len, 1).expand(batch_size, num_heads, seq_len, 1)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

def test_basic_with_attention_mask():
    # Test with attention mask that blocks some positions
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 2, 2
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.tensor([[[[1], [2]], [[3], [4]]]], dtype=torch.float32)
    # Mask blocks position 1 for all queries
    attention_mask = torch.tensor([[[[0, -1e9], [0, -1e9]]]], dtype=torch.float32)
    scaling = 1.0
    dropout = 0.0

def test_basic_dropout_training_mode():
    # Dropout should be applied only in training mode
    module_train = DummyModule(training=True)
    module_eval = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 2, 2
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.ones(batch_size, num_heads, seq_len, head_dim)
    attention_mask = None
    scaling = 1.0
    dropout = 0.5
# ===============================
# EDGE TEST CASES
# ===============================
def test_edge_zero_length_sequence():
    # Sequence length 0: should return empty tensors
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 0, 4
    query = torch.empty(batch_size, num_heads, seq_len, head_dim)
    key = torch.empty(batch_size, num_heads, seq_len, head_dim)
    value = torch.empty(batch_size, num_heads, seq_len, head_dim)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

def test_edge_single_element():
    # Sequence length 1: softmax should be 1, output should be value
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 1, 3
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.tensor([[[[5, 6, 7]]]], dtype=torch.float32)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

def test_edge_high_scaling():
    # Large scaling should push softmax to be very peaked
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 3, 3
    query = torch.tensor([[[[10, 0, 0], [0, 10, 0], [0, 0, 10]]]], dtype=torch.float32)
    key = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    value = torch.arange(seq_len).float().reshape(1, 1, seq_len, 1)
    attention_mask = None
    scaling = 100.0
    dropout = 0.0

def test_edge_negative_mask():
    # Attention mask with large negative values should block attention
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 2, 2
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.tensor([[[[1], [2]], [[3], [4]]]], dtype=torch.float32)
    attention_mask = torch.tensor([[[[0, -1e9], [0, -1e9]]]], dtype=torch.float32)
    scaling = 1.0
    dropout = 0.0
def test_edge_incorrect_shapes():
    # Should raise error if shapes do not match
    module = DummyModule(training=False)
    query = torch.ones(1, 1, 2, 3)
    key = torch.ones(1, 1, 3, 3)
    value = torch.ones(1, 1, 2, 3)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0
    # key has 3 key positions but value has only 2, so the attention @ value matmul fails
    with pytest.raises(RuntimeError):
        eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout)  # 123μs -> 115μs (6.88% faster)

def test_edge_dtype_consistency():
    # Output dtype should match input query dtype
    module = DummyModule(training=False)
    query = torch.ones(1, 1, 2, 2, dtype=torch.float16)
    key = torch.ones(1, 1, 2, 2, dtype=torch.float16)
    value = torch.ones(1, 1, 2, 2, dtype=torch.float16)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0
    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 80.8μs -> 73.4μs (10.1% faster)
# ===============================
# LARGE SCALE TEST CASES
# ===============================
def test_large_scale_attention():
    # Large batch, heads, sequence, but <100MB tensor size
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 8, 8, 64, 32
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len)
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.0

def test_large_scale_masked_attention():
    # Large scale with mask: block half the positions
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 4, 4, 32, 16
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    # Mask out half positions
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len)
    attention_mask[:, :, :, :seq_len // 2] = -1e9
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.0

def test_large_scale_dropout_training():
    # Large scale with dropout in training mode
    module = DummyModule(training=True)
    batch_size, num_heads, seq_len, head_dim = 2, 2, 32, 8
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len)
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.2

def test_large_scale_different_dtypes():
    # Large scale with float16 input
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 2, 2, 32, 8
    query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16)
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len, dtype=torch.float16)
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.0
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
import torch
from torch import nn

from transformers.models.markuplm.modeling_markuplm import eager_attention_forward

# ----------------------- UNIT TESTS -----------------------

# Helper module for dropout
class DummyModule(nn.Module):
    def __init__(self, training=False):
        super().__init__()
        self.training = training

# ----------- BASIC TEST CASES -----------
def test_basic_identity_attention():
    # Test with query == key == value == identity, scaling=1.0, no mask, no dropout
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.eye(seq_len).reshape(batch, heads, seq_len, dim)
    k = torch.eye(seq_len).reshape(batch, heads, seq_len, dim)
    v = torch.eye(seq_len).reshape(batch, heads, seq_len, dim)
    output, weights = eager_attention_forward(module, q, k, v, None, scaling=1.0)  # 72.9μs -> 64.8μs (12.4% faster)
    # With no mask, every row of the attention weights should sum to 1
    for b in range(batch):
        for h in range(heads):
            for i in range(seq_len):
                row_sum = float(weights[b, h, i].sum())
def test_basic_masked_attention():
    # Test with a simple mask that blocks one position
    batch, heads, seq_len, dim = 1, 1, 3, 3
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    # Mask out last position for last query
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    mask[:, :, 2, 2] = float('-inf')
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 74.3μs -> 63.3μs (17.4% faster)
    # The last row should sum to 1 (softmax)
    row_sum = float(weights[0, 0, 2].sum())

def test_basic_dropout_behavior():
    # Test dropout disables when module.training=False and enables when True
    batch, heads, seq_len, dim = 1, 1, 4, 4
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = None
    # No dropout if not training
    module = DummyModule(training=False)
    out1, w1 = eager_attention_forward(module, q, k, v, mask, scaling=1.0, dropout=0.5)  # 74.1μs -> 65.7μs (12.8% faster)
    # Dropout if training
    module = DummyModule(training=True)
    out2, w2 = eager_attention_forward(module, q, k, v, mask, scaling=1.0, dropout=0.5)  # 35.8μs -> 39.0μs (8.33% slower)

def test_basic_scaling_effect():
    # Test that scaling affects the sharpness of softmax
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.ones(batch, heads, seq_len, dim)
    k = torch.ones(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    # Large scaling should make softmax more peaky
    _, w_small = eager_attention_forward(module, q, k, v, None, scaling=0.1)  # 67.5μs -> 59.2μs (14.0% faster)
    _, w_large = eager_attention_forward(module, q, k, v, None, scaling=10.0)  # 20.1μs -> 16.9μs (19.4% faster)
    # For large scaling, one value should be close to 1, the other close to 0
    maxval = float(w_large[0, 0, 0].max())
    minval = float(w_large[0, 0, 0].min())
# ----------- EDGE TEST CASES -----------
def test_empty_sequence():
    # Test with zero-length sequence
    batch, heads, seq_len, dim = 1, 1, 0, 4
    module = DummyModule(training=False)
    q = torch.empty(batch, heads, seq_len, dim)
    k = torch.empty(batch, heads, seq_len, dim)
    v = torch.empty(batch, heads, seq_len, dim)
    mask = None
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 61.1μs -> 54.1μs (13.0% faster)

def test_singleton_batch_and_head():
    # Test with batch=1, heads=1, seq_len=1, dim=1
    module = DummyModule(training=False)
    q = torch.tensor([[[[1.0]]]])
    k = torch.tensor([[[[1.0]]]])
    v = torch.tensor([[[[2.0]]]])
    mask = None
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 69.0μs -> 60.5μs (13.9% faster)
def test_full_mask_out():
    # Test where the mask blocks all positions for a query
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.full((batch, heads, seq_len, seq_len), float('-inf'))
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 85.7μs -> 74.0μs (15.8% faster)
    # With every score pushed to -inf, each softmax row is 0/0 (NaN in PyTorch);
    # the loop below only walks the weights without asserting anything
    for i in range(seq_len):
        for j in range(seq_len):
            pass
def test_incorrect_mask_shape_raises():
    # Mask shape should be (batch, heads, tgt_len, src_len)
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    # Intentionally wrong shape for mask
    mask = torch.zeros(batch, heads, seq_len)
    with pytest.raises(IndexError):
        eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 86.9μs -> 86.1μs (0.913% faster)

def test_float16_and_float32_consistency():
    # Test that float16 and float32 produce similar results (within tolerance)
    batch, heads, seq_len, dim = 1, 1, 4, 4
    module = DummyModule(training=False)
    q32 = torch.randn(batch, heads, seq_len, dim, dtype=torch.float32)
    k32 = torch.randn(batch, heads, seq_len, dim, dtype=torch.float32)
    v32 = torch.randn(batch, heads, seq_len, dim, dtype=torch.float32)
    q16 = q32.to(torch.float16)
    k16 = k32.to(torch.float16)
    v16 = v32.to(torch.float16)
    out32, w32 = eager_attention_forward(module, q32, k32, v32, None, scaling=1.0)  # 71.5μs -> 62.9μs (13.6% faster)
    out16, w16 = eager_attention_forward(module, q16, k16, v16, None, scaling=1.0)  # 26.6μs -> 24.3μs (9.41% faster)
# ----------- LARGE SCALE TEST CASES -----------
def test_large_batch_and_sequence():
    # Test with large batch and sequence, but under 100MB
    batch, heads, seq_len, dim = 8, 4, 32, 32
    # Total elements per tensor: 8*4*32*32 = 32768, about 0.13MB in float32
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 178μs -> 158μs (12.8% faster)
    # Check normalization for a random row
    row_sum = float(weights[0, 0, 0].sum())
def test_large_heads():
    # Test with large number of heads
    batch, heads, seq_len, dim = 2, 32, 16, 8
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 138μs -> 123μs (12.4% faster)
    # Check normalization for a random head
    row_sum = float(weights[1, 10, 5].sum())

def test_large_dim():
    # Test with large embedding dimension
    batch, heads, seq_len, dim = 1, 2, 8, 128
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 105μs -> 94.5μs (11.8% faster)
    # Check normalization for a random position
    row_sum = float(weights[0, 1, 3].sum())

def test_large_masked_attention():
    # Test with large mask, masking out half the positions
    batch, heads, seq_len, dim = 2, 2, 16, 16
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    mask[:, :, :, :8] = float('-inf')  # mask out first half of keys
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 88.9μs -> 76.9μs (15.5% faster)
    # Unmasked positions should sum to 1 for each row
    for b in range(batch):
        for h in range(heads):
            for i in range(seq_len):
                row_sum = float(weights[b, h, i, 8:].sum())
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run git checkout codeflash/optimize-eager_attention_forward-mhvn1bda and push.