@codeflash-ai codeflash-ai bot commented Nov 6, 2025

📄 10% (0.10x) speedup for contrastive_loss in src/transformers/models/clap/modeling_clap.py

⏱️ Runtime : 6.30 milliseconds → 5.71 milliseconds (best of 152 runs)

📝 Explanation and details

The optimized version introduces **tensor caching** to eliminate redundant `torch.arange()` calls, which provides a 10% speedup by avoiding repeated tensor creation overhead.

**Key Optimization:**

- **Caches pre-computed label tensors** using a function-level cache keyed by `(batch_size, device)`
- **Eliminates redundant tensor allocation** when the same batch size and device are used repeatedly
- **Preserves exact functionality** while reducing computational overhead

**Why This Works:**
The original code calls `torch.arange(len(logits), device=logits.device)` on every invocation, which creates a new tensor each time. PyTorch tensor creation involves memory allocation and device placement overhead. The optimization caches these label tensors, since they are fully determined by batch size and device.
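For illustration, here is a minimal sketch of the caching pattern described above; the cache name `_labels_cache` and its exact structure are assumptions, and the actual generated code may differ:

```python
import torch
from torch import nn

# Hypothetical module-level cache; keys are (batch_size, device) pairs.
_labels_cache = {}

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    key = (logits.shape[0], logits.device)
    labels = _labels_cache.get(key)
    if labels is None:
        # Cache miss: build the [0, 1, ..., N-1] target tensor once per key.
        labels = torch.arange(logits.shape[0], device=logits.device)
        _labels_cache[key] = labels
    return nn.functional.cross_entropy(logits, labels)
```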

**Performance Benefits by Test Category:**

- **Small tensors (batch size 1-8)**: 50-97% faster, as cache hits eliminate tensor creation overhead
- **Medium tensors (100x100)**: 48-55% faster, as cache benefits outweigh lookup costs
- **Large tensors (512x512, 999x999)**: 3-11% faster, since tensor creation is small relative to the cross-entropy computation

**Impact Analysis:**
Since contrastive loss is commonly used in training loops where the same batch sizes are processed repeatedly, this caching strategy is particularly effective. The cache is bounded by the number of unique `(batch_size, device)` combinations encountered, making it memory-efficient for typical ML workloads.

The optimization is most beneficial for repeated calls with identical batch sizes, which is the common pattern during training epochs.
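As a rough illustration of that repeated-batch-size pattern, a hypothetical micro-benchmark (numbers will vary by machine; `contrastive_loss` here refers to the cached sketch above):

```python
import time
import torch

logits = torch.randn(128, 128)
contrastive_loss(logits)  # warm-up: first call populates the cache entry

start = time.perf_counter()
for _ in range(1000):
    contrastive_loss(logits)  # cache hits reuse the same label tensor
print(f"1000 cached calls: {(time.perf_counter() - start) * 1e3:.2f} ms")
```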

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 41 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch  # used for tensor creation and manipulation
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html#CLIP-loss-function
from torch import nn
from transformers.models.clap.modeling_clap import contrastive_loss

# unit tests

# ----------- Basic Test Cases -------------

def test_basic_identity_logits():
    # Logits where each row is a one-hot vector for the correct class
    logits = torch.eye(4) * 100  # Large positive for the correct class
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 27.9μs -> 18.1μs (54.0% faster)

def test_basic_uniform_logits():
    # Logits are all zeros, so softmax is uniform
    logits = torch.zeros((5, 5))
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 35.5μs -> 19.0μs (87.2% faster)
    # With uniform probabilities, cross-entropy should be log(num_classes)
    expected_loss = torch.log(torch.tensor(5.0)).item()
    assert abs(loss.item() - expected_loss) < 1e-5

def test_basic_random_logits():
    # Logits are random, but shape is valid
    torch.manual_seed(42)
    logits = torch.randn(3, 3)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 34.0μs -> 22.5μs (51.0% faster)

# ----------- Edge Test Cases -------------

def test_edge_single_example():
    # Only one example, logits shape is (1, 1); with a single class the loss is exactly zero
    logits = torch.tensor([[2.0]])
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 36.7μs -> 18.9μs (94.8% faster)
    assert loss.item() == pytest.approx(0.0)



def test_edge_negative_logits():
    # Logits with negative values
    logits = -torch.eye(3) * 10
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 31.6μs -> 20.4μs (55.2% faster)

def test_edge_large_negative_logits():
    # Large negative logits for correct class, positive for others
    logits = torch.full((3, 3), 10.0)
    for i in range(3):
        logits[i, i] = -100.0
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 31.9μs -> 21.1μs (51.3% faster)

def test_edge_non_float_logits():
    # Logits as integers
    logits = torch.eye(3, dtype=torch.int32)
    with pytest.raises(RuntimeError):
        contrastive_loss(logits) # 87.3μs -> 75.3μs (15.9% faster)

def test_edge_device_consistency_cpu():
    # Logits on CPU
    logits = torch.eye(2)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 37.4μs -> 23.5μs (59.5% faster)


def test_edge_inf_nan_logits():
    # Logits containing inf or nan should result in nan loss
    logits = torch.eye(2)
    logits[0, 0] = float('inf')
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 32.4μs -> 20.1μs (61.1% faster)
    assert torch.isnan(loss)

    logits = torch.eye(2)
    logits[1, 1] = float('nan')
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 8.43μs -> 4.87μs (73.0% faster)
    assert torch.isnan(loss)

# ----------- Large Scale Test Cases -------------

def test_large_scale_100x100():
    # Large batch and class size, but within memory limits
    torch.manual_seed(0)
    logits = torch.randn(100, 100)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 44.0μs -> 28.3μs (55.5% faster)

def test_large_scale_999x999_perfect():
    # Large batch, perfect prediction (diagonal high values)
    logits = torch.eye(999) * 100
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 3.52ms -> 3.41ms (3.02% faster)

def test_large_scale_999x999_uniform():
    # Large batch, uniform logits
    logits = torch.zeros((999, 999))
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 610μs -> 609μs (0.214% faster)
    expected_loss = torch.log(torch.tensor(999.0)).item()
    assert abs(loss.item() - expected_loss) < 1e-4

def test_large_scale_random_logits():
    # Large batch, random logits
    torch.manual_seed(123)
    logits = torch.randn(500, 500)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 178μs -> 161μs (10.7% faster)

# ----------- Mutation-sensitive checks -------------

def test_mutation_sensitive_label_order():
    # If labels are reversed, loss should be much higher
    logits = torch.eye(5) * 10
    labels = torch.arange(4, -1, -1)
    codeflash_output = contrastive_loss(logits); loss_true = codeflash_output # 26.3μs -> 19.7μs (33.2% faster)
    loss_wrong = nn.functional.cross_entropy(logits, labels)
    assert loss_true < loss_wrong

def test_mutation_sensitive_wrong_labels():
    # If labels are (almost) all incorrect, loss should be high
    logits = torch.eye(6) * 10
    labels = torch.full((6,), 0)
    codeflash_output = contrastive_loss(logits); loss_true = codeflash_output # 26.4μs -> 16.6μs (59.3% faster)
    loss_wrong = nn.functional.cross_entropy(logits, labels)
    assert loss_true < loss_wrong

# ----------- Determinism -------------

def test_determinism():
    # Repeated calls with same input should yield same output
    torch.manual_seed(777)
    logits = torch.randn(8, 8)
    codeflash_output = contrastive_loss(logits); loss1 = codeflash_output # 36.8μs -> 21.4μs (71.6% faster)
    codeflash_output = contrastive_loss(logits); loss2 = codeflash_output # 8.86μs -> 5.13μs (72.8% faster)
    assert torch.equal(loss1, loss2)

# ----------- Documentation and Readability -------------

def test_loss_is_scalar_and_on_same_device():
    # Loss should be scalar and on same device as input
    logits = torch.eye(4)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 32.4μs -> 17.5μs (85.6% faster)
    assert loss.dim() == 0
    assert loss.device == logits.device
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html#CLIP-loss-function
import torch  # used for tensor operations
from torch import nn
from transformers.models.clap.modeling_clap import contrastive_loss

# unit tests

# --- Basic Test Cases ---

def test_basic_identity_logits():
    # Test with identity matrix logits: highest score at correct class
    logits = torch.eye(4) * 10  # Large positive on diagonal
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 28.8μs -> 19.1μs (50.4% faster)

def test_basic_uniform_logits():
    # Test with all logits equal: model is maximally uncertain
    logits = torch.zeros((3, 3))
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 36.9μs -> 22.5μs (63.8% faster)
    # Cross-entropy for uniform logits is log(num_classes)
    expected = torch.log(torch.tensor(3.0))
    assert torch.isclose(loss, expected)

def test_basic_batch_size_1():
    # Test with batch size 1
    logits = torch.tensor([[2.0, 1.0, 0.0]])
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 36.9μs -> 18.7μs (97.1% faster)
    # Should match cross-entropy for correct label 0
    expected = nn.functional.cross_entropy(logits, torch.tensor([0]))
    assert torch.isclose(loss, expected)

def test_basic_random_logits():
    # Test with random logits, check deterministic output
    torch.manual_seed(42)
    logits = torch.randn(5, 5)
    codeflash_output = contrastive_loss(logits); loss1 = codeflash_output # 36.5μs -> 21.1μs (72.9% faster)
    torch.manual_seed(42)
    logits2 = torch.randn(5, 5)
    codeflash_output = contrastive_loss(logits2); loss2 = codeflash_output # 10.9μs -> 6.44μs (69.2% faster)
    assert torch.equal(loss1, loss2)

# --- Edge Test Cases ---


def test_edge_non_square_logits():
    # Test with non-square logits (batch_size != num_classes)
    logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    # Here, batch_size=2, num_classes=2, so labels=[0,1], should work
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 43.9μs -> 27.0μs (62.7% faster)

    # Now, batch_size=2, num_classes=3, labels=[0,1], should work
    logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 9.14μs -> 5.05μs (81.0% faster)

def test_edge_single_class():
    # Test with single class (degenerate case): loss is exactly zero
    logits = torch.tensor([[10.0]])
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 35.2μs -> 18.5μs (89.8% faster)
    assert loss.item() == pytest.approx(0.0)

def test_edge_negative_logits():
    # Test with negative logits
    logits = -torch.eye(3) * 5
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 27.3μs -> 16.8μs (62.0% faster)

def test_edge_large_negative_logits():
    # Test with large negative logits (should not produce NaN)
    logits = torch.full((4, 4), -1e6)
    # Set diagonal to 0 so correct class is less negative
    for i in range(4):
        logits[i, i] = 0
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 30.8μs -> 19.5μs (58.2% faster)

def test_edge_cuda_cpu_consistency():
    # Test that loss is the same on CPU and CUDA (if available)
    logits = torch.randn(8, 8)
    codeflash_output = contrastive_loss(logits); loss_cpu = codeflash_output # 37.0μs -> 22.0μs (68.4% faster)
    if torch.cuda.is_available():
        logits_cuda = logits.cuda()
        codeflash_output = contrastive_loss(logits_cuda); loss_cuda = codeflash_output
        assert torch.isclose(loss_cpu, loss_cuda.cpu(), atol=1e-5)

def test_edge_dtype_consistency():
    # Test with float16 and float32: results should agree within half precision
    logits = torch.randn(6, 6)
    codeflash_output = contrastive_loss(logits.float()); loss32 = codeflash_output # 35.9μs -> 18.7μs (91.7% faster)
    codeflash_output = contrastive_loss(logits.half()); loss16 = codeflash_output # 10.6μs -> 7.02μs (50.7% faster)
    assert torch.isclose(loss32, loss16.float(), atol=1e-2)

def test_edge_requires_grad():
    # Test that loss is differentiable
    logits = torch.randn(4, 4, requires_grad=True)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 41.7μs -> 28.8μs (45.0% faster)
    loss.backward()
    assert logits.grad is not None

def test_edge_non_contiguous_input():
    # Test with non-contiguous input tensor
    logits = torch.randn(10, 10).transpose(0, 1)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 46.2μs -> 35.0μs (32.1% faster)

# --- Large Scale Test Cases ---

def test_large_scale_100x100():
    # Test with large batch and class count (100x100)
    logits = torch.randn(100, 100)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 43.9μs -> 29.1μs (50.7% faster)

def test_large_scale_512x512():
    # Test with maximum allowed size (512x512, ~1MB)
    logits = torch.randn(512, 512)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 172μs -> 155μs (10.8% faster)

def test_large_scale_performance():
    # Test that function runs quickly on large input (timing not enforced, but should not hang)
    logits = torch.randn(999, 999)
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 639μs -> 610μs (4.71% faster)

def test_large_scale_high_values():
    # Test with very large positive values
    logits = torch.full((50, 50), 1e6)
    # Set diagonal to 2e6 so correct class is even higher
    for i in range(50):
        logits[i, i] = 2e6
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 36.6μs -> 24.9μs (47.0% faster)

def test_large_scale_low_values():
    # Test with very large negative values
    logits = torch.full((50, 50), -1e6)
    # Set diagonal to 0 so correct class is less negative
    for i in range(50):
        logits[i, i] = 0
    codeflash_output = contrastive_loss(logits); loss = codeflash_output # 34.9μs -> 23.6μs (47.7% faster)

def test_large_scale_random_seed_consistency():
    # Test that random seed produces deterministic results for large input
    torch.manual_seed(1234)
    logits = torch.randn(100, 100)
    codeflash_output = contrastive_loss(logits); loss1 = codeflash_output # 43.5μs -> 29.4μs (48.0% faster)
    torch.manual_seed(1234)
    logits2 = torch.randn(100, 100)
    codeflash_output = contrastive_loss(logits2); loss2 = codeflash_output # 17.5μs -> 12.4μs (41.4% faster)
    assert torch.equal(loss1, loss2)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-contrastive_loss-mhmshnx3` and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 6, 2025 02:10
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Nov 6, 2025