diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index fca566591a..4e8a6a0463 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Optional import numpy as np import torch @@ -19,14 +20,14 @@ import torch.nn.functional as F from torch.nn import LayerNorm -from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding +from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding from monai.networks.layers import Conv, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"} -SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} +SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"} class PatchEmbeddingBlock(nn.Module): @@ -53,6 +54,7 @@ def __init__( pos_embed_type: str = "learnable", dropout_rate: float = 0.0, spatial_dims: int = 3, + pos_embed_kwargs: Optional[dict] = None, ) -> None: """ Args: @@ -65,6 +67,8 @@ def __init__( pos_embed_type: position embedding layer type. dropout_rate: fraction of the input units to drop. spatial_dims: number of spatial dimensions. + pos_embed_kwargs: additional arguments for position embedding. For `sincos`, it can contain + `temperature` and for fourier it can contain `scales`. 
""" super().__init__() @@ -105,6 +109,8 @@ def __init__( self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) + pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs + if self.pos_embed_type == "none": pass elif self.pos_embed_type == "learnable": @@ -114,7 +120,17 @@ def __init__( for in_size, pa_size in zip(img_size, patch_size): grid_size.append(in_size // pa_size) - self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) + self.position_embeddings = build_sincos_position_embedding( + grid_size, hidden_size, spatial_dims, **pos_embed_kwargs + ) + elif self.pos_embed_type == "fourier": + grid_size = [] + for in_size, pa_size in zip(img_size, patch_size): + grid_size.append(in_size // pa_size) + + self.position_embeddings = build_fourier_position_embedding( + grid_size, hidden_size, spatial_dims, **pos_embed_kwargs + ) else: raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.") diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index a9c5176bc2..266be5e28c 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -__all__ = ["build_sincos_position_embedding"] +__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"] # From PyTorch internals @@ -32,6 +32,59 @@ def parse(x): return parse +def build_fourier_position_embedding( + grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0 +) -> torch.nn.Parameter: + """ + Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension, + spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant + points more distinguishable. 
+ Position embedding is made anisotropic by allowing setting different scales for each spatial dimension. + Reference: https://arxiv.org/abs/2509.02488 + + Args: + grid_size (int | List[int]): The size of the grid in each spatial dimension. + embed_dim (int): The dimension of the embedding. + spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D). + scales (float | List[float]): The scale for every spatial dimension. If a single float is provided, + the same scale is used for all dimensions. + + Returns: + pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter. + """ + + to_tuple = _ntuple(spatial_dims) + grid_size_t = to_tuple(grid_size) + if len(grid_size_t) != spatial_dims: + raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.") + + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even for Fourier position embedding") + + # Ensure scales is a tensor of shape (spatial_dims,) + if isinstance(scales, float): + scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float) + elif isinstance(scales, (list, tuple)): + if len(scales) != spatial_dims: + raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}") + scales_tensor = torch.tensor(scales, dtype=torch.float) + else: + raise TypeError(f"scales must be float or list of floats, got {type(scales)}") + + gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor + + position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t] + positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1) + positions = positions.flatten(end_dim=-2) + + x_proj = (2.0 * torch.pi * positions) @ gaussians.T + + pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False) + + return pos_emb + + def build_sincos_position_embedding( grid_size: 
Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 ) -> torch.nn.Parameter: diff --git a/tests/networks/blocks/test_patchembedding.py b/tests/networks/blocks/test_patchembedding.py index 2945482649..95eba14e6f 100644 --- a/tests/networks/blocks/test_patchembedding.py +++ b/tests/networks/blocks/test_patchembedding.py @@ -87,6 +87,19 @@ def test_sincos_pos_embed(self): self.assertEqual(net.position_embeddings.requires_grad, False) + def test_fourier_pos_embed(self): + net = PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(8, 8, 8), + hidden_size=96, + num_heads=8, + pos_embed_type="fourier", + dropout_rate=0.5, + ) + + self.assertEqual(net.position_embeddings.requires_grad, False) + def test_learnable_pos_embed(self): net = PatchEmbeddingBlock( in_channels=1, @@ -101,6 +114,32 @@ def test_learnable_pos_embed(self): self.assertEqual(net.position_embeddings.requires_grad, True) def test_ill_arg(self): + with self.assertRaises(ValueError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + num_heads=12, + proj_type="conv", + dropout_rate=0.1, + pos_embed_type="fourier", + pos_embed_kwargs=dict(scales=[1.0, 1.0]), + ) + + with self.assertRaises(ValueError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(128, 128), + patch_size=(16, 16), + hidden_size=128, + num_heads=12, + proj_type="conv", + dropout_rate=0.1, + pos_embed_type="fourier", + pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]), + ) + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=1,