@codeflash-ai codeflash-ai bot commented Nov 12, 2025

📄 12% (0.12x) speedup for eager_attention_forward in src/transformers/models/markuplm/modeling_markuplm.py

⏱️ Runtime : 3.19 milliseconds → 2.84 milliseconds (best of 250 runs)

📝 Explanation and details

The optimized code achieves a 12% speedup through several key micro-optimizations that reduce overhead and memory operations:

What optimizations were applied:

  1. Cached frequently accessed attributes - Store nn.functional, query.dtype, and key.shape[-2] in local variables to avoid repeated attribute lookups
  2. In-place operations - Replace * scaling with attn_weights.mul_(scaling) and + causal_mask with attn_weights.add_(causal_mask) to reduce memory allocation
  3. Conditional dtype conversion - Only call .to(query.dtype) when the query isn't already float32, avoiding unnecessary type conversion overhead
  4. Smart dropout handling - Skip the nn.functional.dropout call entirely when dropout=0.0 or module.training=False, eliminating a no-op function call
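
Taken together, the rewritten hot path would look roughly like the sketch below. This is a minimal reconstruction from the summary above, not the exact diff; the call signature matches the one exercised by the generated tests further down, while the trailing transpose/contiguous on the output is assumed from the usual Transformers convention.

import torch
from torch import nn

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    F = nn.functional  # 1. cache the attribute lookup once per call
    attn_weights = torch.matmul(query, key.transpose(2, 3))
    attn_weights.mul_(scaling)  # 2. in-place scaling instead of allocating a new tensor
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights.add_(causal_mask)  # 2. in-place mask addition
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
    if attn_weights.dtype != query.dtype:
        attn_weights = attn_weights.to(query.dtype)  # 3. convert only when the inputs are not float32
    if dropout > 0.0 and module.training:
        attn_weights = F.dropout(attn_weights, p=dropout, training=True)  # 4. skipped entirely when it would be a no-op
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights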

Why these optimizations work:

  • Attribute lookup reduction: Python attribute access has overhead; caching nn.functional and tensor properties eliminates repeated dictionary lookups
  • In-place operations: mul_() and add_() modify tensors without creating new ones, reducing memory bandwidth and allocation overhead
  • Avoiding no-op calls: The original code always called dropout regardless of parameters; the optimized version skips this when unnecessary
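
As a rough illustration of the in-place point, the hypothetical micro-benchmark below compares out-of-place and in-place scaling on an attention-weight-sized tensor; absolute numbers vary by machine, but the in-place variant avoids one full-tensor allocation per call.

import timeit

import torch

x = torch.randn(8, 8, 64, 64)  # batch, heads, seq_len, seq_len: a typical attention-weight shape
scale = 0.125

# Repeated in-place scaling drives x toward zero, which is harmless for timing purposes.
out_of_place = timeit.timeit(lambda: x * scale, number=10_000)
in_place = timeit.timeit(lambda: x.mul_(scale), number=10_000)
print(f"out-of-place: {out_of_place:.3f}s  in-place: {in_place:.3f}s")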

Performance characteristics from test results:

  • Consistently 10-19% faster across most test cases, with particularly strong gains on larger tensors (17.6% on large-scale attention)
  • Near-neutral impact on training cases with actual dropout (2.99% faster in one case, 11.8% slower in another, due to the extra branching)
  • Best performance gains on basic operations without dropout where the overhead reduction is most apparent

The optimizations are particularly effective for transformer attention layers where this function is called frequently in inference scenarios with no dropout.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 32 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime

import pytest  # used for our unit tests
import torch
from torch import nn
from transformers.models.markuplm.modeling_markuplm import eager_attention_forward

# unit tests

# Helper to create dummy module

class DummyModule(nn.Module):
    def __init__(self, training=True):
        super().__init__()
        self.training = training

# ===============================
# BASIC TEST CASES
# ===============================

def test_basic_identity_attention():
    # Test with identity query/key/value and no mask, scaling=1, dropout=0
    # Should produce softmax of identity matrix (i.e., one-hot along last dim)
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 3, 3
    query = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    key = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    value = torch.arange(seq_len).float().reshape(1, 1, seq_len, 1)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 60.3μs -> 52.1μs (15.8% faster)
    # attn_weights should be one-hot along last dim
    expected_weights = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    # attn_output should select the value for each position
    expected_output = value.transpose(2, 3)

def test_basic_uniform_attention():
    # All query/key are zeros, so softmax should be uniform
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 2, 1, 4, 5
    query = torch.zeros(batch_size, num_heads, seq_len, head_dim)
    key = torch.zeros(batch_size, num_heads, seq_len, head_dim)
    value = torch.arange(seq_len).float().reshape(1, 1, seq_len, 1).expand(batch_size, num_heads, seq_len, 1)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 56.8μs -> 48.4μs (17.4% faster)
    # attn_weights should be uniform along last dim
    expected_weights = torch.full((batch_size, num_heads, seq_len, seq_len), 1.0/seq_len)
    # attn_output should be mean of values along last dim
    expected_output = value.mean(dim=2, keepdim=True).expand(batch_size, seq_len, num_heads, 1).transpose(1, 2)

def test_basic_with_attention_mask():
    # Test with attention mask that blocks some positions
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 2, 2
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.tensor([[[[1],[2]],[[3],[4]]]], dtype=torch.float32)
    # Mask blocks position 1 for all queries
    attention_mask = torch.tensor([[[[0, -1e9],[0, -1e9]]]], dtype=torch.float32)
    scaling = 1.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 88.7μs -> 79.1μs (12.2% faster)
    # Only position 0 should be attended for both queries
    expected_weights = torch.tensor([[[[1,0],[1,0]]]], dtype=torch.float32)
    # Output should be value at position 0 for both queries
    expected_output = torch.tensor([[[[1],[3]],[[1],[3]]]], dtype=torch.float32).transpose(1, 2)

def test_basic_dropout_training_mode():
    # Dropout should be applied only in training mode
    module_train = DummyModule(training=True)
    module_eval = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 2, 2
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.ones(batch_size, num_heads, seq_len, head_dim)
    attention_mask = None
    scaling = 1.0
    dropout = 0.5

    # In eval mode, dropout should not change output
    attn_output_eval, attn_weights_eval = eager_attention_forward(
        module_eval, query, key, value, attention_mask, scaling, dropout
    )  # 70.6μs -> 61.6μs (14.5% faster)
    attn_output_no_dropout, attn_weights_no_dropout = eager_attention_forward(
        module_eval, query, key, value, attention_mask, scaling, 0.0
    )  # 20.6μs -> 17.2μs (19.7% faster)

    # In training mode, dropout should change output (with high probability)
    attn_output_train, attn_weights_train = eager_attention_forward(
        module_train, query, key, value, attention_mask, scaling, dropout
    )  # 30.9μs -> 35.1μs (11.8% slower)

# ===============================
# EDGE TEST CASES
# ===============================

def test_edge_zero_length_sequence():
    # Sequence length 0: should return empty tensors
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 0, 4
    query = torch.empty(batch_size, num_heads, seq_len, head_dim)
    key = torch.empty(batch_size, num_heads, seq_len, head_dim)
    value = torch.empty(batch_size, num_heads, seq_len, head_dim)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 60.3μs -> 51.4μs (17.4% faster)

def test_edge_single_element():
    # Sequence length 1: softmax should be 1, output should be value
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 1, 3
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.tensor([[[[5, 6, 7]]]], dtype=torch.float32)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 65.4μs -> 56.3μs (16.3% faster)

def test_edge_high_scaling():
    # Large scaling should push softmax to be very peaked
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 3, 3
    query = torch.tensor([[[[10,0,0],[0,10,0],[0,0,10]]]], dtype=torch.float32)
    key = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    value = torch.arange(seq_len).float().reshape(1, 1, seq_len, 1)
    attention_mask = None
    scaling = 100.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 59.6μs -> 50.4μs (18.3% faster)
    # Each query should attend almost exclusively to its matching key
    expected_weights = torch.eye(seq_len).reshape(1, 1, seq_len, seq_len)
    # Output should be value at each position
    expected_output = value.transpose(2, 3)

def test_edge_negative_mask():
    # Attention mask with large negative values should block attention
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 1, 1, 2, 2
    query = torch.ones(batch_size, num_heads, seq_len, head_dim)
    key = torch.ones(batch_size, num_heads, seq_len, head_dim)
    value = torch.tensor([[[[1],[2]],[[3],[4]]]], dtype=torch.float32)
    attention_mask = torch.tensor([[[[0, -1e9],[0, -1e9]]]], dtype=torch.float32)
    scaling = 1.0
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 90.4μs -> 80.0μs (13.0% faster)
    # Only position 0 should be attended for both queries
    expected_weights = torch.tensor([[[[1,0],[1,0]]]], dtype=torch.float32)

def test_edge_incorrect_shapes():
    # Should raise error if shapes do not match
    module = DummyModule(training=False)
    query = torch.ones(1, 1, 2, 3)
    key = torch.ones(1, 1, 3, 3)
    value = torch.ones(1, 1, 2, 3)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0
    # key.shape[-2] != value.shape[-2], so the second matmul (weights @ value) fails
    with pytest.raises(RuntimeError):
        eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout)  # 123μs -> 115μs (6.88% faster)

def test_edge_dtype_consistency():
    # Output dtype should match input query dtype
    module = DummyModule(training=False)
    query = torch.ones(1, 1, 2, 2, dtype=torch.float16)
    key = torch.ones(1, 1, 2, 2, dtype=torch.float16)
    value = torch.ones(1, 1, 2, 2, dtype=torch.float16)
    attention_mask = None
    scaling = 1.0
    dropout = 0.0
    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 80.8μs -> 73.4μs (10.1% faster)

# ===============================
# LARGE SCALE TEST CASES
# ===============================

def test_large_scale_attention():
    # Large batch, heads, sequence, but <100MB tensor size
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 8, 8, 64, 32
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len)
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 680μs -> 578μs (17.6% faster)
    # Check that softmax sums to 1 along last dim
    sums = attn_weights.sum(dim=-1)

def test_large_scale_masked_attention():
    # Large scale with mask: block half the positions
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 4, 4, 32, 16
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    # Mask out half positions
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len)
    attention_mask[:,:,:,:seq_len//2] = -1e9
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 113μs -> 100μs (12.6% faster)
    # The masked positions should be zero in softmax
    masked = attn_weights[:,:,:,:seq_len//2]

def test_large_scale_dropout_training():
    # Large scale with dropout in training mode
    module = DummyModule(training=True)
    batch_size, num_heads, seq_len, head_dim = 2, 2, 32, 8
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len)
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.2

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 153μs -> 149μs (2.99% faster)
    # Dropout should zero out some weights (with high probability)
    num_zeros = (attn_weights == 0).sum().item()

def test_large_scale_different_dtypes():
    # Large scale with float16 input
    module = DummyModule(training=False)
    batch_size, num_heads, seq_len, head_dim = 2, 2, 32, 8
    query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=torch.float16)
    attention_mask = torch.zeros(batch_size, num_heads, seq_len, seq_len, dtype=torch.float16)
    scaling = 1.0 / (head_dim ** 0.5)
    dropout = 0.0

    attn_output, attn_weights = eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout
    )  # 181μs -> 169μs (6.86% faster)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

#------------------------------------------------
import pytest  # used for our unit tests
import torch
from torch import nn
from transformers.models.markuplm.modeling_markuplm import eager_attention_forward

# ----------------------- UNIT TESTS -----------------------

# Helper module for dropout

class DummyModule(nn.Module):
    def __init__(self, training=False):
        super().__init__()
        self.training = training

# ----------- BASIC TEST CASES -----------

def test_basic_identity_attention():
    # Test with query == key == value == identity, scaling=1.0, no mask, no dropout
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.eye(seq_len).reshape(batch, heads, seq_len, dim)
    k = torch.eye(seq_len).reshape(batch, heads, seq_len, dim)
    v = torch.eye(seq_len).reshape(batch, heads, seq_len, dim)
    output, weights = eager_attention_forward(module, q, k, v, None, scaling=1.0)  # 72.9μs -> 64.8μs (12.4% faster)
    # For identity, attention should be uniform if not masked, so rows of weights should sum to 1
    for b in range(batch):
        for h in range(heads):
            for i in range(seq_len):
                row_sum = float(weights[b, h, i].sum())

def test_basic_masked_attention():
    # Test with a simple mask that blocks one position
    batch, heads, seq_len, dim = 1, 1, 3, 3
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    # Mask out last position for last query
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    mask[:, :, 2, 2] = float('-inf')
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 74.3μs -> 63.3μs (17.4% faster)
    # The last row should sum to 1 (softmax)
    row_sum = float(weights[0, 0, 2].sum())

def test_basic_dropout_behavior():
    # Test dropout disables when module.training=False and enables when True
    batch, heads, seq_len, dim = 1, 1, 4, 4
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = None
    # No dropout if not training
    module = DummyModule(training=False)
    out1, w1 = eager_attention_forward(module, q, k, v, mask, scaling=1.0, dropout=0.5)  # 74.1μs -> 65.7μs (12.8% faster)
    # Dropout if training
    module = DummyModule(training=True)
    out2, w2 = eager_attention_forward(module, q, k, v, mask, scaling=1.0, dropout=0.5)  # 35.8μs -> 39.0μs (8.33% slower)

def test_basic_scaling_effect():
    # Test that scaling affects the sharpness of softmax
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.ones(batch, heads, seq_len, dim)
    k = torch.ones(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    # Large scaling should make softmax more peaky
    _, w_small = eager_attention_forward(module, q, k, v, None, scaling=0.1)  # 67.5μs -> 59.2μs (14.0% faster)
    _, w_large = eager_attention_forward(module, q, k, v, None, scaling=10.0)  # 20.1μs -> 16.9μs (19.4% faster)
    # For large scaling, one value should be close to 1, the other close to 0
    maxval = float(w_large[0, 0, 0].max())
    minval = float(w_large[0, 0, 0].min())

# ----------- EDGE TEST CASES -----------

def test_empty_sequence():
    # Test with zero-length sequence
    batch, heads, seq_len, dim = 1, 1, 0, 4
    module = DummyModule(training=False)
    q = torch.empty(batch, heads, seq_len, dim)
    k = torch.empty(batch, heads, seq_len, dim)
    v = torch.empty(batch, heads, seq_len, dim)
    mask = None
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 61.1μs -> 54.1μs (13.0% faster)

def test_singleton_batch_and_head():
    # Test with batch=1, heads=1, seq_len=1, dim=1
    module = DummyModule(training=False)
    q = torch.tensor([[[[1.0]]]])
    k = torch.tensor([[[[1.0]]]])
    v = torch.tensor([[[[2.0]]]])
    mask = None
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 69.0μs -> 60.5μs (13.9% faster)

def test_full_mask_out():
    # Test where the mask blocks all positions for a query
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.full((batch, heads, seq_len, seq_len), float('-inf'))
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 85.7μs -> 74.0μs (15.8% faster)
    # Every position is masked (all scores -inf), so the softmax rows are degenerate; just exercise the call
    for i in range(seq_len):
        for j in range(seq_len):
            pass

def test_incorrect_mask_shape_raises():
    # Mask shape should be (batch, heads, tgt_len, src_len)
    batch, heads, seq_len, dim = 1, 1, 2, 2
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    # Intentionally wrong shape for mask
    mask = torch.zeros(batch, heads, seq_len)
    with pytest.raises(IndexError):
        eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 86.9μs -> 86.1μs (0.913% faster)

def test_float16_and_float32_consistency():
    # Test that float16 and float32 produce similar results (within tolerance)
    batch, heads, seq_len, dim = 1, 1, 4, 4
    module = DummyModule(training=False)
    q32 = torch.randn(batch, heads, seq_len, dim, dtype=torch.float32)
    k32 = torch.randn(batch, heads, seq_len, dim, dtype=torch.float32)
    v32 = torch.randn(batch, heads, seq_len, dim, dtype=torch.float32)
    q16 = q32.to(torch.float16)
    k16 = k32.to(torch.float16)
    v16 = v32.to(torch.float16)
    out32, w32 = eager_attention_forward(module, q32, k32, v32, None, scaling=1.0)  # 71.5μs -> 62.9μs (13.6% faster)
    out16, w16 = eager_attention_forward(module, q16, k16, v16, None, scaling=1.0)  # 26.6μs -> 24.3μs (9.41% faster)

# ----------- LARGE SCALE TEST CASES -----------

def test_large_batch_and_sequence():
    # Test with large batch and sequence, but under 100MB
    batch, heads, seq_len, dim = 8, 4, 32, 32
    # Total elements per tensor: 8*4*32*32 = 32768, well under 1MB in float32
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 178μs -> 158μs (12.8% faster)
    # Check normalization for a random row
    row_sum = float(weights[0, 0, 0].sum())

def test_large_heads():
    # Test with large number of heads
    batch, heads, seq_len, dim = 2, 32, 16, 8
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 138μs -> 123μs (12.4% faster)
    # Check normalization for a random head
    row_sum = float(weights[1, 10, 5].sum())

def test_large_dim():
    # Test with large embedding dimension
    batch, heads, seq_len, dim = 1, 2, 8, 128
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 105μs -> 94.5μs (11.8% faster)
    # Check normalization for a random position
    row_sum = float(weights[0, 1, 3].sum())

def test_large_masked_attention():
    # Test with large mask, masking out half the positions
    batch, heads, seq_len, dim = 2, 2, 16, 16
    module = DummyModule(training=False)
    q = torch.randn(batch, heads, seq_len, dim)
    k = torch.randn(batch, heads, seq_len, dim)
    v = torch.randn(batch, heads, seq_len, dim)
    mask = torch.zeros(batch, heads, seq_len, seq_len)
    mask[:, :, :, :8] = float('-inf')  # mask out first half of keys
    output, weights = eager_attention_forward(module, q, k, v, mask, scaling=1.0)  # 88.9μs -> 76.9μs (15.5% faster)
    # Unmasked positions should sum to 1 for each row
    for b in range(batch):
        for h in range(heads):
            for i in range(seq_len):
                row_sum = float(weights[b, h, i, 8:].sum())

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-eager_attention_forward-mhvn1bda` and push.


@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 12, 2025 06:47
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 12, 2025