From 93c935462d0ada7bb044a278f8e93ace006521f9 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 11:59:32 +0200
Subject: [PATCH 01/12] Add Fourier feature positional encoding (#8564)

---
 monai/networks/blocks/patchembedding.py      | 19 ++++++--
 monai/networks/blocks/pos_embed_utils.py     | 46 +++++++++++++++++++-
 tests/networks/blocks/test_patchembedding.py | 38 ++++++++++++++++
 3 files changed, 99 insertions(+), 4 deletions(-)

diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py
index fca566591a..53d626a622 100644
--- a/monai/networks/blocks/patchembedding.py
+++ b/monai/networks/blocks/patchembedding.py
@@ -19,14 +19,14 @@
 import torch.nn.functional as F
 from torch.nn import LayerNorm
 
-from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
 from monai.networks.layers import Conv, trunc_normal_
 from monai.utils import ensure_tuple_rep, optional_import
 from monai.utils.module import look_up_option
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
-SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}
 
 
 class PatchEmbeddingBlock(nn.Module):
@@ -53,6 +53,7 @@ def __init__(
         pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
+        pos_embed_kwargs: dict = {},
     ) -> None:
         """
         Args:
@@ -65,6 +66,8 @@ def __init__(
             pos_embed_type: position embedding layer type.
             dropout_rate: fraction of the input units to drop.
             spatial_dims: number of spatial dimensions.
+            pos_embed_kwargs: additional arguments for position embedding. For `sincos`, it can contain
+                `temperature`, and for `fourier` it can contain `scales`.
         """
         super().__init__()
@@ -114,7 +117,17 @@ def __init__(
             for in_size, pa_size in zip(img_size, patch_size):
                 grid_size.append(in_size // pa_size)
 
-            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
+            self.position_embeddings = build_sincos_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
+        elif self.pos_embed_type == "fourier":
+            grid_size = []
+            for in_size, pa_size in zip(img_size, patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.position_embeddings = build_fourier_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
 
diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index a9c5176bc2..b7e4ff9b82 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 
-__all__ = ["build_sincos_position_embedding"]
+__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]
 
 
 # From PyTorch internals
@@ -32,6 +32,50 @@ def parse(x):
     return parse
 
 
+def build_fourier_position_embedding(
+    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
+):
+    """
+    Builds an (anisotropic) Fourier feature based positional encoding based on the given grid size, embed dimension,
+    spatial dimensions, and scales. The scales control the variance of the Fourier features; higher values make distant
+    points more distinguishable.
+
+    Reference: https://arxiv.org/abs/2509.02488
+
+    Args:
+        grid_size (List[int]): The size of the grid in each spatial dimension.
+        embed_dim (int): The dimension of the embedding.
+        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
+        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
+            the same scale is used for all dimensions.
+
+    Returns:
+        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+    """
+
+    to_tuple = _ntuple(spatial_dims)
+    grid_size = to_tuple(grid_size)
+
+    scales = torch.tensor(scales)
+    if scales.ndim > 1 and scales.ndim != spatial_dims:
+        raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
+    if scales.ndim == 0:
+        scales = scales.repeat(spatial_dims)
+
+    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
+    gaussians = gaussians * scales
+
+    positions = [torch.linspace(0, 1, x) for x in grid_size]
+    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), axis=-1)
+    positions = positions.flatten(end_dim=-2)
+
+    x_proj = (2.0 * torch.pi * positions) @ gaussians.T
+
+    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], axis=-1)
+    pos_emb = pos_emb[None, :, :]
+
+    return nn.Parameter(pos_emb, requires_grad=False)
+
+
 def build_sincos_position_embedding(
     grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
 ) -> torch.nn.Parameter:
diff --git a/tests/networks/blocks/test_patchembedding.py b/tests/networks/blocks/test_patchembedding.py
index 2945482649..2d6de2f4dd 100644
--- a/tests/networks/blocks/test_patchembedding.py
+++ b/tests/networks/blocks/test_patchembedding.py
@@ -87,6 +87,19 @@ def test_sincos_pos_embed(self):
 
         self.assertEqual(net.position_embeddings.requires_grad, False)
 
+    def test_fourier_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="fourier",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+
     def test_learnable_pos_embed(self):
         net = PatchEmbeddingBlock(
             in_channels=1,
@@ -101,6 +114,31 @@ def test_learnable_pos_embed(self):
         self.assertEqual(net.position_embeddings.requires_grad, True)
 
     def test_ill_arg(self):
+        with self.assertRaises(ValueError):
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(128, 128, 128),
+                patch_size=(16, 16, 16),
+                hidden_size=128,
+                num_heads=12,
+                proj_type="conv",
+                dropout_rate=5.0,
+                pos_embed_type="fourier",
+                pos_embed_kwargs=dict(scales=[1.0, 1.0]),
+            )
+
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(128, 128),
+                patch_size=(16, 16),
+                hidden_size=128,
+                num_heads=12,
+                proj_type="conv",
+                dropout_rate=5.0,
+                pos_embed_type="fourier",
+                pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
+            )
+
         with self.assertRaises(ValueError):
             PatchEmbeddingBlock(
                 in_channels=1,
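For review context, [PATCH 01/12] builds the embedding by normalizing each patch-grid coordinate to [0, 1], projecting it through a fixed Gaussian random matrix, and taking sin/cos of the projection, giving a non-learnable token embedding. A minimal standalone sketch of the same construction with concrete shapes (local names like `b` and `coords` are illustrative, not part of the patch):

    import torch

    grid_size, embed_dim, spatial_dims = (4, 4, 4), 96, 3  # e.g. a 32^3 volume with 8^3 patches

    # Fixed Gaussian projection matrix, one row per Fourier feature.
    b = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))  # (48, 3)

    # Normalized grid coordinates in [0, 1] for every patch position.
    axes = [torch.linspace(0, 1, g) for g in grid_size]
    coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).flatten(end_dim=-2)  # (64, 3)

    proj = 2.0 * torch.pi * coords @ b.T                    # (64, 48)
    pos_emb = torch.cat([proj.sin(), proj.cos()], dim=-1)   # (64, 96)
    assert pos_emb.shape == (64, embed_dim)
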
From 2794dec1f6758eb68d0ca5d19a48e43cd4c755d5 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 12:18:15 +0200
Subject: [PATCH 02/12] Remove mutable default argument

---
 monai/networks/blocks/patchembedding.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py
index 53d626a622..4e8a6a0463 100644
--- a/monai/networks/blocks/patchembedding.py
+++ b/monai/networks/blocks/patchembedding.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from typing import Optional
 
 import numpy as np
 import torch
@@ -53,7 +54,7 @@ def __init__(
         pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
-        pos_embed_kwargs: dict = {},
+        pos_embed_kwargs: Optional[dict] = None,
     ) -> None:
         """
         Args:
@@ -108,6 +109,8 @@ def __init__(
             self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
 
+        pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs
+
         if self.pos_embed_type == "none":
             pass
         elif self.pos_embed_type == "learnable":
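The motivation for [PATCH 02/12]: a mutable default such as `pos_embed_kwargs: dict = {}` is created once at function definition and shared by every call, so mutations leak across instances. A minimal illustration of the pitfall and of the `None`-sentinel pattern the patch adopts (hypothetical function, not from the codebase):

    def append_to(item, bucket=[]):  # the single default list is shared across calls
        bucket.append(item)
        return bucket

    append_to(1)  # [1]
    append_to(2)  # [1, 2] -- state leaked from the first call

    def append_to_fixed(item, bucket=None):  # sentinel pattern used in the patch
        bucket = [] if bucket is None else bucket
        bucket.append(item)
        return bucket
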
From 984bbe403489d9754369f193f2bb3cad24a801e1 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 12:40:15 +0200
Subject: [PATCH 03/12] NabJa DCO Remediation Commit for NabJa

I, NabJa, hereby add my Signed-off-by to this commit: 93c935462d0ada7bb044a278f8e93ace006521f9
I, NabJa, hereby add my Signed-off-by to this commit: 2794dec1f6758eb68d0ca5d19a48e43cd4c755d5

Signed-off-by: NabJa

From bfb55920ea97a1bbc639e1c9bd4711ca22b93f6e Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 13:34:22 +0200
Subject: [PATCH 04/12] Add embed dim assertion

---
 monai/networks/blocks/pos_embed_utils.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index b7e4ff9b82..b7c0f9fc1b 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -34,7 +34,7 @@ def parse(x):
 
 def build_fourier_position_embedding(
     grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
-):
+) -> torch.nn.Parameter:
     """
     Builds an (anisotropic) Fourier feature based positional encoding based on the given grid size, embed dimension,
     spatial dimensions, and scales. The scales control the variance of the Fourier features; higher values make distant
@@ -55,7 +55,12 @@ def build_fourier_position_embedding(
     to_tuple = _ntuple(spatial_dims)
     grid_size = to_tuple(grid_size)
 
-    scales = torch.tensor(scales)
+    if embed_dim % (2 * spatial_dims) != 0:
+        raise AssertionError(
+            f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
+        )
+
+    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.float)
     if scales.ndim > 1 and scales.ndim != spatial_dims:
         raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
     if scales.ndim == 0:
@@ -65,15 +70,15 @@ def build_fourier_position_embedding(
     gaussians = gaussians * scales
 
     positions = [torch.linspace(0, 1, x) for x in grid_size]
-    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), axis=-1)
+    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)
 
     x_proj = (2.0 * torch.pi * positions) @ gaussians.T
 
-    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], axis=-1)
-    pos_emb = pos_emb[None, :, :]
+    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+    pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False)
 
-    return nn.Parameter(pos_emb, requires_grad=False)
+    return pos_emb
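At this point in the series the new check requires `embed_dim` to be a multiple of `2 * spatial_dims` (a constraint [PATCH 12/12] later relaxes to simple evenness). A quick illustration of what it accepts and rejects, assuming the function is imported from the module path the patch touches:

    from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

    build_fourier_position_embedding([4, 4, 4], embed_dim=96, spatial_dims=3)  # ok: 96 % 6 == 0
    build_fourier_position_embedding([4, 4, 4], embed_dim=64, spatial_dims=3)  # AssertionError: 64 % 6 != 0
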
From 3f5b1255748f25bcfa063308f3e798ceb783fba5 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 13:38:08 +0200
Subject: [PATCH 05/12] Correct PatchEmbeddingBlock test dropout

---
 tests/networks/blocks/test_patchembedding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/networks/blocks/test_patchembedding.py b/tests/networks/blocks/test_patchembedding.py
index 2d6de2f4dd..95eba14e6f 100644
--- a/tests/networks/blocks/test_patchembedding.py
+++ b/tests/networks/blocks/test_patchembedding.py
@@ -122,11 +122,12 @@ def test_ill_arg(self):
                 hidden_size=128,
                 num_heads=12,
                 proj_type="conv",
-                dropout_rate=5.0,
+                dropout_rate=0.1,
                 pos_embed_type="fourier",
                 pos_embed_kwargs=dict(scales=[1.0, 1.0]),
             )
 
+        with self.assertRaises(ValueError):
             PatchEmbeddingBlock(
                 in_channels=1,
                 img_size=(128, 128),
@@ -134,7 +135,7 @@ def test_ill_arg(self):
                 hidden_size=128,
                 num_heads=12,
                 proj_type="conv",
-                dropout_rate=5.0,
+                dropout_rate=0.1,
                 pos_embed_type="fourier",
                 pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
             )

From 6c885a11ac1258c798d9108797273d99461feb70 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 14:41:06 +0200
Subject: [PATCH 06/12] DCO Remediation Commit for NabJa

I, NabJa, hereby add my Signed-off-by to this commit: bfb55920ea97a1bbc639e1c9bd4711ca22b93f6e
I, NabJa, hereby add my Signed-off-by to this commit: 3f5b1255748f25bcfa063308f3e798ceb783fba5

Signed-off-by: NabJa

From 28aefc2f9dc620f514b8f44c4e556d2587490568 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 14:42:54 +0200
Subject: [PATCH 07/12] Fix type and flake8 errors

Signed-off-by: NabJa
---
 monai/networks/blocks/pos_embed_utils.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index b7c0f9fc1b..f124623a94 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -60,17 +60,23 @@ def build_fourier_position_embedding(
             f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
         )
 
-    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.float)
-    if scales.ndim > 1 and scales.ndim != spatial_dims:
-        raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
-    if scales.ndim == 0:
-        scales = scales.repeat(spatial_dims)
+    # Ensure scales is a tensor of shape (spatial_dims,)
+    if isinstance(scales, float):
+        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
+    elif isinstance(scales, (list, tuple)):
+        if len(scales) != spatial_dims:
+            raise ValueError(
+                f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}"
+            )
+        scales_tensor = torch.tensor(scales, dtype=torch.float)
+    else:
+        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
 
     gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
-    gaussians = gaussians * scales
+    gaussians = gaussians * scales_tensor
 
-    positions = [torch.linspace(0, 1, x) for x in grid_size]
-    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
+    position_indeces = [torch.linspace(0, 1, x) for x in grid_size]
+    positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)
 
     x_proj = (2.0 * torch.pi * positions) @ gaussians.T
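After [PATCH 07/12], `scales` is validated by Python type rather than by tensor shape. A short sketch of the anisotropic use case the parameter enables (illustrative values; shapes follow from an 8 x 8 x 4 patch grid):

    from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

    # One scale per axis; a larger scale yields higher-frequency features along
    # that axis, e.g. to compensate for anisotropic voxel spacing.
    emb = build_fourier_position_embedding([8, 8, 4], embed_dim=96, spatial_dims=3, scales=[1.0, 1.0, 2.0])
    print(emb.shape)          # torch.Size([1, 256, 96])
    print(emb.requires_grad)  # False
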
From 3ef218890fd005bb23521239a44f94d34155d763 Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 16:48:19 +0200
Subject: [PATCH 08/12] Code formatting

Signed-off-by: NabJa
---
 monai/networks/blocks/pos_embed_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index f124623a94..ddc3f57fde 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -65,9 +65,7 @@ def build_fourier_position_embedding(
         scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
     elif isinstance(scales, (list, tuple)):
         if len(scales) != spatial_dims:
-            raise ValueError(
-                f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}"
-            )
+            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
         scales_tensor = torch.tensor(scales, dtype=torch.float)
     else:
         raise TypeError(f"scales must be float or list of floats, got {type(scales)}")

From 68584e5124f3ede475bf98da4d97b9b25ec1312d Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 17:09:54 +0200
Subject: [PATCH 09/12] Add grid_size check and fix typing

Signed-off-by: NabJa
---
 monai/networks/blocks/pos_embed_utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index ddc3f57fde..a0f67c86cf 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -53,7 +53,11 @@ def build_fourier_position_embedding(
     """
 
     to_tuple = _ntuple(spatial_dims)
-    grid_size = to_tuple(grid_size)
+    grid_size_t = to_tuple(grid_size)
+    if len(grid_size_t) != spatial_dims:
+        raise ValueError(
+            f"Length of grid_size must be the same as spatial_dims. Got len(grid_size)={len(grid_size_t)}, should be {spatial_dims}."
+        )
 
     if embed_dim % (2 * spatial_dims) != 0:
         raise AssertionError(
@@ -73,7 +77,7 @@ def build_fourier_position_embedding(
     gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
     gaussians = gaussians * scales_tensor
 
-    position_indeces = [torch.linspace(0, 1, x) for x in grid_size]
+    position_indeces = [torch.linspace(0, 1, x) for x in grid_size_t]
    positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)
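The new length check in [PATCH 09/12] matters because `_ntuple` broadcasts a scalar but passes an explicit sequence through unchanged, so a mismatched list previously surfaced only later as a confusing matmul shape error. Illustrative behaviour, assuming the same import as above:

    from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

    build_fourier_position_embedding(4, embed_dim=96, spatial_dims=3)       # grid_size -> (4, 4, 4), ok
    build_fourier_position_embedding([4, 4], embed_dim=96, spatial_dims=3)  # now a clear ValueError
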
From 6d4c703c059e7d4fa5e681a9c4a42efe602be79f Mon Sep 17 00:00:00 2001
From: NabJa
Date: Mon, 15 Sep 2025 17:28:15 +0200
Subject: [PATCH 10/12] Fix flake8 line too long

Signed-off-by: NabJa
---
 monai/networks/blocks/pos_embed_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py
index a0f67c86cf..8e481e7480 100644
--- a/monai/networks/blocks/pos_embed_utils.py
+++ b/monai/networks/blocks/pos_embed_utils.py
@@ -56,7 +56,7 @@ def build_fourier_position_embedding(
     grid_size_t = to_tuple(grid_size)
     if len(grid_size_t) != spatial_dims:
         raise ValueError(
-            f"Length of grid_size must be the same as spatial_dims. Got len(grid_size)={len(grid_size_t)}, should be {spatial_dims}."
+            f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims."
         )
- ) + raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.") if embed_dim % (2 * spatial_dims) != 0: raise AssertionError( From 8a375b61e1fbd783940f2727445d9df372844344 Mon Sep 17 00:00:00 2001 From: NabJa Date: Tue, 16 Sep 2025 10:24:34 +0200 Subject: [PATCH 12/12] Fixed overrestrictive embed_dim check, improved code style Signed-off-by: NabJa --- monai/networks/blocks/pos_embed_utils.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index 924ddf2381..266be5e28c 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"] +__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"] # From PyTorch internals @@ -36,16 +36,17 @@ def build_fourier_position_embedding( grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0 ) -> torch.nn.Parameter: """ - Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension, + Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension, spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant points more distinguishable. + Position embedding is made anistropic by allowing setting different scales for each spatial dimension. Reference: https://arxiv.org/abs/2509.02488 Args: - grid_size (List[int]): The size of the grid in each spatial dimension. + grid_size (int | List[int]): The size of the grid in each spatial dimension. embed_dim (int): The dimension of the embedding. spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D). - scales (List[float]): The scale for every spatial dimension. If a single float is provided, + scales (float | List[float]): The scale for every spatial dimension. If a single float is provided, the same scale is used for all dimensions. Returns: @@ -57,10 +58,8 @@ def build_fourier_position_embedding( if len(grid_size_t) != spatial_dims: raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.") - if embed_dim % (2 * spatial_dims) != 0: - raise AssertionError( - f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding" - ) + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even for Fourier position embedding") # Ensure scales is a tensor of shape (spatial_dims,) if isinstance(scales, float): @@ -72,11 +71,10 @@ def build_fourier_position_embedding( else: raise TypeError(f"scales must be float or list of floats, got {type(scales)}") - gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims)) - gaussians = gaussians * scales_tensor + gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor - position_indeces = [torch.linspace(0, 1, x) for x in grid_size_t] - positions = torch.stack(torch.meshgrid(*position_indeces, indexing="ij"), dim=-1) + position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t] + positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1) positions = positions.flatten(end_dim=-2) x_proj = (2.0 * torch.pi * positions) @ gaussians.T