Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "attention_smithy"
version = "1.1.1"
version = "1.2.0"
authors = [
{ name="Caleb Cranney", email="11773171+CCranney@users.noreply.github.com" },
]
Expand Down
3 changes: 2 additions & 1 deletion src/attention_smithy/numeric_embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@

from attention_smithy.numeric_embeddings.specialized.SinusoidalCustomEmbedding import SinusoidalCustomEmbedding
from attention_smithy.numeric_embeddings.specialized.ALiBiCustomEmbedding import ALiBiCustomEmbedding

from attention_smithy.numeric_embeddings.specialized.RotaryCustomEmbedding import RotaryCustomEmbedding
from attention_smithy.numeric_embeddings.specialized.ContinuousValueEmbedding import ContinuousValueEmbedding
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ def create_positional_or_custom_embedding(self, **kwargs) -> torch.Tensor:
"""
pass

from abc import ABC, abstractmethod

class MatrixModificationStrategyBase(ABC, nn.Module):
"""
Abstract base class for strategies that modify or transform matrices.
Expand All @@ -38,8 +36,6 @@ def modify_matrix(self, target_matrix, **kwargs) -> torch.Tensor:
"""
pass

from abc import ABC, abstractmethod

class AttentionBiasStrategyBase(ABC, nn.Module):
"""
Abstract base class for strategies that generate bias tensors to be added to attention score matrices.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from torch import nn
from attention_smithy.numeric_embeddings.abstract_embedding_strategies import NumericEmbeddingStrategyBase

class ContinuousValueEmbedding(NumericEmbeddingStrategyBase):
    """
    Maps continuous scalar inputs into the model's embedding space.

    Each scalar is passed through a small two-layer MLP
    (linear -> tanh -> bias-free linear), giving a learnable, non-linear
    encoding for numeric values such as scalar features or non-position
    numbers.

    See https://doi.org/10.1145/3516367.
    """

    def __init__(self, embedding_dimension: int) -> None:
        super().__init__()
        # Hidden width grows with the square root of the output dimension.
        bottleneck_width = int(embedding_dimension ** 0.5)
        self.net = nn.Sequential(
            nn.Linear(1, bottleneck_width),
            nn.Tanh(),
            nn.Linear(bottleneck_width, embedding_dimension, bias=False),
        )

    def forward(self, continuous_values: torch.Tensor, **kwargs) -> torch.Tensor:
        # Append a trailing feature axis so each scalar becomes a length-1 vector.
        expanded_values = continuous_values.unsqueeze(-1)
        return self.net(expanded_values)

    def create_positional_or_custom_embedding(self, continuous_values: torch.Tensor, **kwargs) -> torch.Tensor:
        # Strategy-base interface hook; delegates straight to forward().
        return self(continuous_values, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
from torch import nn
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from einops import rearrange
from typing import Optional
from attention_smithy.numeric_embeddings.abstract_embedding_strategies import MatrixModificationStrategyBase

class RotaryCustomEmbedding(MatrixModificationStrategyBase):
    """
    Rotary embedding wrapper that accepts arbitrary scalar positions
    (e.g., [0.4, 2.6, 10.1]) in place of the standard integer positions
    (0, 1, 2, ...).

    Applies rotary embeddings to input tensors such as queries or keys
    using those custom position values.
    """

    def __init__(self, head_dimension: int):
        """
        Args:
            head_dimension (int): Embedding dimension per attention head (must be divisible by 2).
        """
        super().__init__()
        # Caching is disabled: positions can differ on every call.
        self.rotary = RotaryEmbedding(dim=head_dimension, cache_if_possible=False)

    def modify_matrix(self, target_matrix: torch.Tensor, **kwargs) -> torch.Tensor:
        # Strategy-base interface hook; delegates to forward().
        return self.forward(target_matrix, **kwargs)

    def forward(
        self,
        target_matrix: torch.Tensor,
        rotary_custom_values: torch.Tensor,
        seq_dim: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            target_matrix (torch.Tensor): The input Q or K matrix of shape (..., seq_len, dim).
            rotary_custom_values (torch.Tensor): Float tensor of shape (seq_len,) specifying custom position values.
            seq_dim (int, optional): The dimension in `tensor` corresponding to sequence length. Defaults to -2.

        Returns:
            torch.Tensor: Tensor with rotary embeddings applied.
        """
        if seq_dim is None:
            seq_dim = -2
        expected_length = target_matrix.shape[seq_dim]
        provided_length = rotary_custom_values.shape[0]

        if provided_length != expected_length:
            raise ValueError(f"rotary_custom_values length {provided_length} must match sequence length {expected_length}")

        # Build per-position rotation frequencies from the custom scalar positions.
        freqs = self.rotary.forward(rotary_custom_values, seq_len=expected_length)

        # When the sequence axis sits at -3, insert a broadcast axis so the
        # frequencies line up with the extra trailing dimension.
        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')

        return apply_rotary_emb(freqs, target_matrix, seq_dim=seq_dim)
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def test__ALiBiCustomEmbedding__numHeads2():
[-1.0, 0.0, -2.0, -0.5, -1.5],
],
[
[-0.25, -0.25, -0.75, 0.00, -0.50],
[-0.25, -0.75, -0.25, -0.50, 0.00],
[-0.50, -1.00, 0.00, -0.75, -0.25],
[0.00, -0.50, -0.50, -0.25, -0.25],
[-0.50, 0.00, -1.00, -0.25, -0.75],
[-0.25, -0.25, -0.75, 0.00, -0.50],
[-0.25, -0.75, -0.25, -0.50, 0.00],
[-0.50, -1.00, 0.00, -0.75, -0.25],
[ 0.00, -0.50, -0.50, -0.25, -0.25],
[-0.50, 0.00, -1.00, -0.25, -0.75],
],
])[None, :, :, :]
output = embedding(query_values, key_values, None)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
import pytest
from attention_smithy.numeric_embeddings import ContinuousValueEmbedding

@pytest.fixture
def model():
    # Fresh 64-dimensional embedding module for each test.
    return ContinuousValueEmbedding(embedding_dimension=64)

def test__ContinuousValueEmbedding__output_shape_matches_expected(model):
    """A (batch, seq) float input should produce a (batch, seq, 64) embedding."""
    batch_size, sequence_length = 8, 10
    values = torch.randn(batch_size, sequence_length)
    embedded = model(values)
    assert embedded.shape == (batch_size, sequence_length, 64)

def test__ContinuousValueEmbedding__runs_on_different_input_shapes(model):
    """The module should handle a range of (batch, seq) input sizes."""
    shape_cases = [(1, 1), (4, 7), (16, 32)]
    for batch_size, sequence_length in shape_cases:
        embedded = model(torch.randn(batch_size, sequence_length))
        assert embedded.shape == (batch_size, sequence_length, 64)

def test__ContinuousValueEmbedding__has_expected_number_of_parameters():
    """Two linear layers, the second bias-free, yield exactly three tensors."""
    embedding = ContinuousValueEmbedding(embedding_dimension=64)
    parameter_tensors = list(embedding.parameters())
    assert len(parameter_tensors) == 3
    # The final projection layer was built with bias=False.
    assert embedding.net[2].bias is None

def test__ContinuousValueEmbedding__create_embedding_matches_forward(model):
    """The strategy-interface method must delegate to forward()."""
    values = torch.randn(3, 5)
    via_interface = model.create_positional_or_custom_embedding(values)
    via_forward = model(values)
    assert torch.allclose(via_forward, via_interface)

def test__ContinuousValueEmbedding__gradient_flow(model):
    """Backpropagation should reach both the input tensor and every weight."""
    values = torch.randn(2, 4, requires_grad=True)
    model(values).sum().backward()
    assert values.grad is not None
    for parameter in model.parameters():
        assert parameter.grad is not None
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import pytest
from attention_smithy.numeric_embeddings import RotaryCustomEmbedding

# --------------- FIXTURE -----------------

@pytest.fixture
def dummy_tensor():
    """Random attention-style tensor: (batch=2, heads=4, seq_len=6, dim=64)."""
    shape = (2, 4, 6, 64)
    return torch.randn(*shape)

@pytest.fixture
def custom_values():
    # Six increasing float positions, matching dummy_tensor's seq_len of 6.
    return torch.tensor([0.4, 1.2, 5.5, 13.3, 21.0, 30.5], dtype=torch.float32)

@pytest.fixture
def custom_rotary():
    # head_dimension matches dummy_tensor's last axis (64).
    return RotaryCustomEmbedding(head_dimension=64)

# --------------- TESTS -------------------

def test__RotaryCustomEmbedding__preserves_shape(dummy_tensor, custom_values, custom_rotary):
    """Applying rotary rotation must not change the tensor's shape."""
    rotated = custom_rotary(dummy_tensor, rotary_custom_values=custom_values)
    assert rotated.shape == dummy_tensor.shape

def test__RotaryCustomEmbedding__fails_on_mismatched_positions(dummy_tensor, custom_rotary):
    """Supplying fewer positions than the sequence length must raise ValueError."""
    too_few_positions = torch.tensor([0.1, 0.2])
    expected_message = "rotary_custom_values length .* must match sequence length"
    with pytest.raises(ValueError, match=expected_message):
        custom_rotary(dummy_tensor, rotary_custom_values=too_few_positions)

def test__RotaryCustomEmbedding__different_positions_give_different_results(dummy_tensor, custom_rotary):
    """Distinct position values should produce distinct rotated outputs."""
    num_positions = dummy_tensor.shape[-2]
    low_range_positions = torch.linspace(0.0, 1.0, steps=num_positions)
    high_range_positions = torch.linspace(10.0, 20.0, steps=num_positions)

    rotated_low = custom_rotary(dummy_tensor, rotary_custom_values=low_range_positions)
    rotated_high = custom_rotary(dummy_tensor, rotary_custom_values=high_range_positions)

    assert not torch.allclose(rotated_low, rotated_high, atol=1e-5), "Outputs should differ for different custom positions"

def test__RotaryCustomEmbedding__no_nan_or_inf(dummy_tensor, custom_values, custom_rotary):
    """Sanity check: the rotated output contains only finite values."""
    rotated = custom_rotary(dummy_tensor, rotary_custom_values=custom_values)
    assert not torch.isnan(rotated).any(), "Output contains NaNs"
    assert not torch.isinf(rotated).any(), "Output contains Infs"