This repository was archived by the owner on Apr 15, 2026. It is now read-only.

teamtomo/torch-fourier-slice


Warning

This package has been migrated to the TeamTomo monorepo. Future development, bug fixes, and releases will happen there. This repository is archived and no longer maintained. This package is still published to and installable from the same PyPI project, but development installations should be made from the monorepo.

torch-fourier-slice


Fourier slice extraction/insertion from 2D images and 3D volumes in PyTorch.

Overview

This package provides a simple API for back projection (reconstruction) and forward projection of 3D volumes using Fourier slice insertion and extraction. Both operations are available for:

  • single volumes, with project_3d_to_2d() and backproject_2d_to_3d()
  • multichannel volumes, with project_3d_to_2d_multichannel() and backproject_2d_to_3d_multichannel()

There are also lower-level functions in the package that operate directly on the Fourier transforms of volumes/images, which is useful if the Fourier transform can be precalculated:

  • extract_central_slices_rfft_3d(), extract_central_slices_rfft_3d_multichannel()
  • insert_central_slices_rfft_3d(), insert_central_slices_rfft_3d_multichannel()

The package also supports extracting common lines from 2D images with project_2d_to_1d(), which can be useful for tilt-axis angle optimization in cryo-ET.
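
The functions above rest on the Fourier slice (projection-slice) theorem: the Fourier transform of a projection equals a central slice through the Fourier transform of the object. The sketch below illustrates the 2D to 1D case in plain Python so it can be run without installing anything. It does not use this package and is not its implementation; `dft_1d` and `dft_2d` are hypothetical helper names, and the naive DFT is chosen for readability rather than speed.

```python
import cmath


def dft_1d(signal):
    """Naive O(n^2) discrete Fourier transform of a 1D sequence."""
    n = len(signal)
    return [
        sum(signal[x] * cmath.exp(-2j * cmath.pi * k * x / n) for x in range(n))
        for k in range(n)
    ]


def dft_2d(image):
    """Naive 2D DFT: transform every row, then every column."""
    rows = [dft_1d(row) for row in image]
    n_rows, n_cols = len(rows), len(rows[0])
    cols = [dft_1d([rows[r][c] for r in range(n_rows)]) for c in range(n_cols)]
    # cols[c][r] holds frequency (r, c); transpose back to [row][col] order
    return [[cols[c][r] for c in range(n_cols)] for r in range(n_rows)]


# a small deterministic 4x4 test image
image = [[float((3 * r + 5 * c) % 7) for c in range(4)] for r in range(4)]

# real-space projection: sum over the row axis, leaving one value per column
projection = [sum(image[r][c] for r in range(4)) for c in range(4)]

# the 1D DFT of the projection ...
projection_ft = dft_1d(projection)
# ... should match the zero-frequency row (the central slice) of the 2D DFT
central_slice = dft_2d(image)[0]

# largest absolute difference between the two; expected to be at rounding level
max_error = max(abs(a - b) for a, b in zip(projection_ft, central_slice))
```

Here the slice is axis-aligned, so no interpolation is needed. Slices at arbitrary orientations generally fall between Fourier-space grid points, which is why extraction and insertion at arbitrary rotations involve interpolation.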

Installation

pip install torch-fourier-slice

Usage

Single volume

import torch
from scipy.stats import special_ortho_group
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d

# start with a volume
volume = torch.rand((30, 30, 30))

# and some random rotations (scipy returns float64, so cast to float32 to match the volume)
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float()
# shape is (10, 3, 3)

# forward project the volume, provides 10 projection images
projections = project_3d_to_2d(volume, rotation_matrices)
# shape is (10, 30, 30)

# we can backproject the 10 images to get an approximate reconstruction of the original volume
reconstruction = backproject_2d_to_3d(projections, rotation_matrices)
# shape is (30, 30, 30)

# we can have an arbitrary number of leading dimensions for the rotations
rotation_matrices = torch.rand(3, 10, 3, 3)
projections = project_3d_to_2d(volume, rotation_matrices)
# shape is (3, 10, 30, 30)

# but for reconstruction it needs to match up with the projections
reconstruction = backproject_2d_to_3d(
    projections,  # (3, 10, 30, 30) 
    rotation_matrices  # (3, 10, 3, 3)
)
# shape is (30, 30, 30)
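
The `torch.rand(3, 10, 3, 3)` call above is enough to demonstrate shape handling, but uniformly random entries are not orthonormal, so they are not true rotations. As a minimal sketch, assuming an axis-angle parametrisation is convenient, a valid rotation matrix can be built with Rodrigues' formula in plain Python. The helpers `rotation_matrix`, `orthonormality_error` and `determinant` are hypothetical names for illustration and are not part of this package.

```python
import math


def rotation_matrix(axis, angle):
    """3x3 rotation by `angle` radians about `axis` (Rodrigues' formula)."""
    norm = math.sqrt(sum(a * a for a in axis))
    x, y, z = (a / norm for a in axis)  # normalise the axis to unit length
    c, s = math.cos(angle), math.sin(angle)
    t = 1.0 - c
    return [
        [c + x * x * t, x * y * t - z * s, x * z * t + y * s],
        [y * x * t + z * s, c + y * y * t, y * z * t - x * s],
        [z * x * t - y * s, z * y * t + x * s, c + z * z * t],
    ]


def orthonormality_error(m):
    """Largest absolute entry of (m @ m.T - identity)."""
    return max(
        abs(sum(m[i][k] * m[j][k] for k in range(3)) - (1.0 if i == j else 0.0))
        for i in range(3)
        for j in range(3)
    )


def determinant(m):
    """Determinant of a 3x3 matrix; +1 for a proper rotation."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


# a quarter turn about the third axis, and a rotation about a tilted axis
quarter_turn = rotation_matrix((0.0, 0.0, 1.0), math.pi / 2)
tilted = rotation_matrix((1.0, 1.0, 0.0), 0.7)

# both checks should be at rounding level for a valid rotation
tilted_error = orthonormality_error(tilted)
tilted_det = determinant(tilted)
```

Nested lists like these can be stacked and converted with `torch.tensor(..., dtype=torch.float32)` to obtain a `(..., 3, 3)` tensor. Which coordinate ordering the package expects for its rotation matrices is not covered by this sketch and should be checked against the package documentation.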

Multichannel volumes

import torch
from scipy.stats import special_ortho_group
from torch_fourier_slice import project_3d_to_2d_multichannel, backproject_2d_to_3d_multichannel

# now we start with a multichannel 3d volume
volume = torch.rand((5, 30, 30, 30))

# and some random rotations (scipy returns float64, so cast to float32 to match the volume)
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float()
# shape is (10, 3, 3)

# forward project the volume, provides 10 projection images with 5 channels each
projections = project_3d_to_2d_multichannel(volume, rotation_matrices)
# shape is (10, 5, 30, 30)

# we can backproject the 10 multichannel images to get an approximate reconstruction of the original multichannel volume
reconstruction = backproject_2d_to_3d_multichannel(projections, rotation_matrices)
# shape is (5, 30, 30, 30)

# we can have an arbitrary number of leading dimensions for the rotations with multichannel data as well
rotation_matrices = torch.rand(3, 10, 3, 3)
projections = project_3d_to_2d_multichannel(volume, rotation_matrices)
# shape is (3, 10, 5, 30, 30)

# but for reconstruction it needs to match up with the projections
reconstruction = backproject_2d_to_3d_multichannel(
    projections,  # (3, 10, 5, 30, 30) 
    rotation_matrices  # (3, 10, 3, 3)
)
# shape is (5, 30, 30, 30)

License

This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
