diff --git a/pyproject.toml b/pyproject.toml index dc18c86..7d7051d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "attention_smithy" -version = "1.1.1" +version = "1.2.0" authors = [ { name="Caleb Cranney", email="11773171+CCranney@users.noreply.github.com" }, ] diff --git a/src/attention_smithy/numeric_embeddings/__init__.py b/src/attention_smithy/numeric_embeddings/__init__.py index 80d50e4..93091a6 100644 --- a/src/attention_smithy/numeric_embeddings/__init__.py +++ b/src/attention_smithy/numeric_embeddings/__init__.py @@ -6,4 +6,5 @@ from attention_smithy.numeric_embeddings.specialized.SinusoidalCustomEmbedding import SinusoidalCustomEmbedding from attention_smithy.numeric_embeddings.specialized.ALiBiCustomEmbedding import ALiBiCustomEmbedding - +from attention_smithy.numeric_embeddings.specialized.RotaryCustomEmbedding import RotaryCustomEmbedding +from attention_smithy.numeric_embeddings.specialized.ContinuousValueEmbedding import ContinuousValueEmbedding diff --git a/src/attention_smithy/numeric_embeddings/abstract_embedding_strategies.py b/src/attention_smithy/numeric_embeddings/abstract_embedding_strategies.py index 9f02ef3..5f08450 100644 --- a/src/attention_smithy/numeric_embeddings/abstract_embedding_strategies.py +++ b/src/attention_smithy/numeric_embeddings/abstract_embedding_strategies.py @@ -19,8 +19,6 @@ def create_positional_or_custom_embedding(self, **kwargs) -> torch.Tensor: """ pass -from abc import ABC, abstractmethod - class MatrixModificationStrategyBase(ABC, nn.Module): """ Abstract base class for strategies that modify or transform matrices. @@ -38,8 +36,6 @@ def modify_matrix(self, target_matrix, **kwargs) -> torch.Tensor: """ pass -from abc import ABC, abstractmethod - class AttentionBiasStrategyBase(ABC, nn.Module): """ Abstract base class for strategies that generate bias tensors to be added to attention score matrices. 
diff --git a/src/attention_smithy/numeric_embeddings/specialized/ContinuousValueEmbedding.py b/src/attention_smithy/numeric_embeddings/specialized/ContinuousValueEmbedding.py new file mode 100644 index 0000000..ff6afc7 --- /dev/null +++ b/src/attention_smithy/numeric_embeddings/specialized/ContinuousValueEmbedding.py @@ -0,0 +1,29 @@ +import torch +from torch import nn +from attention_smithy.numeric_embeddings.abstract_embedding_strategies import NumericEmbeddingStrategyBase + +class ContinuousValueEmbedding(NumericEmbeddingStrategyBase): + """ + Encodes continuous scalar values into a high-dimensional embedding space using + a two-layer MLP with a tanh non-linearity. + This approach allows numeric values (e.g., scalar features or non-position numbers) + to be encoded in a learnable, non-linear way. + + See https://doi.org/10.1145/3516367. + """ + + def __init__(self, embedding_dimension: int) -> None: + super().__init__() + hidden_dim = int(embedding_dimension ** 0.5) + self.net = nn.Sequential( + nn.Linear(1, hidden_dim), + nn.Tanh(), + nn.Linear(hidden_dim, embedding_dimension, bias=False) + ) + + def forward(self, continuous_values: torch.Tensor, **kwargs) -> torch.Tensor: + continuous_values = continuous_values.unsqueeze(-1) + return self.net(continuous_values) + + def create_positional_or_custom_embedding(self, continuous_values: torch.Tensor, **kwargs) -> torch.Tensor: + return self(continuous_values, **kwargs) \ No newline at end of file diff --git a/src/attention_smithy/numeric_embeddings/specialized/RotaryCustomEmbedding.py b/src/attention_smithy/numeric_embeddings/specialized/RotaryCustomEmbedding.py new file mode 100644 index 0000000..bea9df7 --- /dev/null +++ b/src/attention_smithy/numeric_embeddings/specialized/RotaryCustomEmbedding.py @@ -0,0 +1,53 @@ +import torch +from torch import nn +from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb +from einops import rearrange +from typing import Optional +from 
attention_smithy.numeric_embeddings.abstract_embedding_strategies import MatrixModificationStrategyBase + +class RotaryCustomEmbedding(MatrixModificationStrategyBase): + """ + Rotary Embedding wrapper that allows applying custom scalar position values + (e.g., [0.4, 2.6, 10.1]) instead of standard integer positions (0, 1, 2, ...). + + Applies rotary embeddings to input tensors like queries or keys using those custom positions. + """ + + def __init__(self, head_dimension: int): + """ + Args: + head_dimension (int): Embedding dimension per attention head (must be divisible by 2). + """ + super().__init__() + self.rotary = RotaryEmbedding(dim=head_dimension, cache_if_possible=False) + + def modify_matrix(self, target_matrix: torch.Tensor, **kwargs) -> torch.Tensor: + return self.forward(target_matrix, **kwargs) + + def forward( + self, + target_matrix: torch.Tensor, + rotary_custom_values: torch.Tensor, + seq_dim: Optional[int] = None + ) -> torch.Tensor: + """ + Args: + target_matrix (torch.Tensor): The input Q or K matrix of shape (..., seq_len, dim). + rotary_custom_values (torch.Tensor): Float tensor of shape (seq_len,) specifying custom position values. + seq_dim (int, optional): The dimension in `target_matrix` corresponding to sequence length. Defaults to -2. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. 
+ """ + seq_dim = seq_dim if seq_dim is not None else -2 + seq_len = target_matrix.shape[seq_dim] + + if rotary_custom_values.shape[0] != seq_len: + raise ValueError(f"rotary_custom_values length {rotary_custom_values.shape[0]} must match sequence length {seq_len}") + + freqs = self.rotary.forward(rotary_custom_values, seq_len=seq_len) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, target_matrix, seq_dim=seq_dim) \ No newline at end of file diff --git a/src/attention_smithy/numeric_embeddings/specialized/tests/test_ALiBiCustomEmbedding.py b/src/attention_smithy/numeric_embeddings/specialized/tests/test_ALiBiCustomEmbedding.py index c788983..0f80cba 100644 --- a/src/attention_smithy/numeric_embeddings/specialized/tests/test_ALiBiCustomEmbedding.py +++ b/src/attention_smithy/numeric_embeddings/specialized/tests/test_ALiBiCustomEmbedding.py @@ -36,11 +36,11 @@ def test__ALiBiCustomEmbedding__numHeads2(): [-1.0, 0.0, -2.0, -0.5, -1.5], ], [ - [-0.25, -0.25, -0.75, 0.00, -0.50], - [-0.25, -0.75, -0.25, -0.50, 0.00], - [-0.50, -1.00, 0.00, -0.75, -0.25], - [0.00, -0.50, -0.50, -0.25, -0.25], - [-0.50, 0.00, -1.00, -0.25, -0.75], + [-0.25, -0.25, -0.75, 0.00, -0.50], + [-0.25, -0.75, -0.25, -0.50, 0.00], + [-0.50, -1.00, 0.00, -0.75, -0.25], + [ 0.00, -0.50, -0.50, -0.25, -0.25], + [-0.50, 0.00, -1.00, -0.25, -0.75], ], ])[None, :, :, :] output = embedding(query_values, key_values, None) diff --git a/src/attention_smithy/numeric_embeddings/specialized/tests/test_ContinuousValueEmbedding.py b/src/attention_smithy/numeric_embeddings/specialized/tests/test_ContinuousValueEmbedding.py new file mode 100644 index 0000000..5e95788 --- /dev/null +++ b/src/attention_smithy/numeric_embeddings/specialized/tests/test_ContinuousValueEmbedding.py @@ -0,0 +1,38 @@ +import torch +import pytest +from attention_smithy.numeric_embeddings import ContinuousValueEmbedding + +@pytest.fixture +def model(): + return 
ContinuousValueEmbedding(embedding_dimension=64) + +def test__ContinuousValueEmbedding__output_shape_matches_expected(model): + numeric_values = torch.randn(8, 10) # batch_size=8, seq_len=10 + output = model(numeric_values) + assert output.shape == (8, 10, 64) + +def test__ContinuousValueEmbedding__runs_on_different_input_shapes(model): + for bsz, seq_len in [(1, 1), (4, 7), (16, 32)]: + numeric_values = torch.randn(bsz, seq_len) + output = model(numeric_values) + assert output.shape == (bsz, seq_len, 64) + +def test__ContinuousValueEmbedding__has_expected_number_of_parameters(): + model = ContinuousValueEmbedding(embedding_dimension=64) + params = list(model.parameters()) + assert len(params) == 3 + assert model.net[2].bias is None + +def test__ContinuousValueEmbedding__create_embedding_matches_forward(model): + numeric_values = torch.randn(3, 5) + output_from_forward = model(numeric_values) + output_from_create = model.create_positional_or_custom_embedding(numeric_values) + assert torch.allclose(output_from_forward, output_from_create) + +def test__ContinuousValueEmbedding__gradient_flow(model): + numeric_values = torch.randn(2, 4, requires_grad=True) + output = model(numeric_values) + loss = output.sum() + loss.backward() + assert numeric_values.grad is not None + assert all(p.grad is not None for p in model.parameters()) \ No newline at end of file diff --git a/src/attention_smithy/numeric_embeddings/specialized/tests/test_RotaryCustomEmbedding.py b/src/attention_smithy/numeric_embeddings/specialized/tests/test_RotaryCustomEmbedding.py new file mode 100644 index 0000000..0b67ded --- /dev/null +++ b/src/attention_smithy/numeric_embeddings/specialized/tests/test_RotaryCustomEmbedding.py @@ -0,0 +1,44 @@ +import torch +import pytest +from attention_smithy.numeric_embeddings import RotaryCustomEmbedding + +# --------------- FIXTURE ----------------- + +@pytest.fixture +def dummy_tensor(): + batch, num_heads, seq_len, dim = 2, 4, 6, 64 + return torch.randn(batch, 
num_heads, seq_len, dim) + +@pytest.fixture +def custom_values(): + return torch.tensor([0.4, 1.2, 5.5, 13.3, 21.0, 30.5], dtype=torch.float32) + +@pytest.fixture +def custom_rotary(): + return RotaryCustomEmbedding(head_dimension=64) + +# --------------- TESTS ------------------- + +def test__RotaryCustomEmbedding__preserves_shape(dummy_tensor, custom_values, custom_rotary): + output = custom_rotary(dummy_tensor, rotary_custom_values=custom_values) + assert output.shape == dummy_tensor.shape + +def test__RotaryCustomEmbedding__fails_on_mismatched_positions(dummy_tensor, custom_rotary): + wrong_positions = torch.tensor([0.1, 0.2]) # Too few positions + with pytest.raises(ValueError, match="rotary_custom_values length .* must match sequence length"): + custom_rotary(dummy_tensor, rotary_custom_values=wrong_positions) + +def test__RotaryCustomEmbedding__different_positions_give_different_results(dummy_tensor, custom_rotary): + seq_len = dummy_tensor.shape[-2] + positions_1 = torch.linspace(0.0, 1.0, steps=seq_len) + positions_2 = torch.linspace(10.0, 20.0, steps=seq_len) + + out1 = custom_rotary(dummy_tensor, rotary_custom_values=positions_1) + out2 = custom_rotary(dummy_tensor, rotary_custom_values=positions_2) + + assert not torch.allclose(out1, out2, atol=1e-5), "Outputs should differ for different custom positions" + +def test__RotaryCustomEmbedding__no_nan_or_inf(dummy_tensor, custom_values, custom_rotary): + output = custom_rotary(dummy_tensor, rotary_custom_values=custom_values) + assert not torch.isnan(output).any(), "Output contains NaNs" + assert not torch.isinf(output).any(), "Output contains Infs" \ No newline at end of file