|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import torch |
8 | | -from torch_ctf import calculate_ctf_2d |
| 8 | +from torch_ctf import calc_LPP_ctf_2D, calculate_ctf_2d |
9 | 9 | from torch_fourier_filter.envelopes import b_envelope |
10 | 10 |
|
11 | 11 | from leopard_em.utils.search_utils import get_cs_range |
@@ -78,25 +78,56 @@ def calculate_ctf_filter_stack_full_args( |
78 | 78 | # If it's neither a list nor a tensor, try to convert it |
79 | 79 | mag_matrix = torch.tensor(mag_matrix, dtype=torch.float32) |
80 | 80 |
|
| 81 | + # When laser phase plate params are provided, use LPP CTF; otherwise standard CTF |
| 82 | + laser_params = kwargs.pop("laser_params", None) |
| 83 | + |
81 | 84 | # Loop over spherical aberrations one at a time and collect results |
82 | 85 | ctf_list = [] |
83 | 86 | for cs_val in cs_values: |
84 | | - tmp = calculate_ctf_2d( |
85 | | - defocus=defocus * 1e-4, # Convert to um from Angstrom |
86 | | - astigmatism=astigmatism * 1e-4, # Convert to um from Angstrom |
87 | | - astigmatism_angle=kwargs["astigmatism_angle"], |
88 | | - voltage=kwargs["voltage"], |
89 | | - spherical_aberration=cs_val, |
90 | | - amplitude_contrast=kwargs["amplitude_contrast_ratio"], |
91 | | - phase_shift=kwargs["phase_shift"], |
92 | | - pixel_size=kwargs["pixel_size"], |
93 | | - image_shape=template_shape, |
94 | | - rfft=True, |
95 | | - fftshift=False, |
96 | | - even_zernike_coeffs=kwargs["even_zernikes"], |
97 | | - odd_zernike_coeffs=kwargs["odd_zernikes"], |
98 | | - transform_matrix=mag_matrix, |
99 | | - ) |
| 87 | + if laser_params is not None: |
| 88 | + tmp = calc_LPP_ctf_2D( |
| 89 | + defocus=defocus * 1e-4, # Convert to um from Angstrom |
| 90 | + astigmatism=astigmatism * 1e-4, # Convert to um from Angstrom |
| 91 | + astigmatism_angle=kwargs["astigmatism_angle"], |
| 92 | + voltage=kwargs["voltage"], |
| 93 | + spherical_aberration=cs_val, |
| 94 | + amplitude_contrast=kwargs["amplitude_contrast_ratio"], |
| 95 | + pixel_size=kwargs["pixel_size"], |
| 96 | + image_shape=template_shape, |
| 97 | + rfft=True, |
| 98 | + fftshift=False, |
| 99 | + NA=laser_params.NA, |
| 100 | + laser_wavelength_angstrom=laser_params.laser_wavelength_angstrom, |
| 101 | + focal_length_angstrom=laser_params.focal_length_angstrom, |
| 102 | + laser_xy_angle_deg=laser_params.laser_xy_angle_deg, |
| 103 | + laser_xz_angle_deg=laser_params.laser_xz_angle_deg, |
| 104 | + laser_long_offset_angstrom=laser_params.laser_long_offset_angstrom, |
| 105 | + laser_trans_offset_angstrom=laser_params.laser_trans_offset_angstrom, |
| 106 | + laser_polarization_angle_deg=laser_params.laser_polarization_angle_deg, |
| 107 | + peak_phase_deg=laser_params.peak_phase_deg, |
| 108 | + dual_laser=laser_params.dual_laser, |
| 109 | + beam_tilt_mrad=None, |
| 110 | + even_zernike_coeffs=kwargs.get("even_zernikes"), |
| 111 | + odd_zernike_coeffs=kwargs.get("odd_zernikes"), |
| 112 | + transform_matrix=mag_matrix, |
| 113 | + ) |
| 114 | + else: |
| 115 | + tmp = calculate_ctf_2d( |
| 116 | + defocus=defocus * 1e-4, # Convert to um from Angstrom |
| 117 | + astigmatism=astigmatism * 1e-4, # Convert to um from Angstrom |
| 118 | + astigmatism_angle=kwargs["astigmatism_angle"], |
| 119 | + voltage=kwargs["voltage"], |
| 120 | + spherical_aberration=cs_val, |
| 121 | + amplitude_contrast=kwargs["amplitude_contrast_ratio"], |
| 122 | + phase_shift=kwargs["phase_shift"], |
| 123 | + pixel_size=kwargs["pixel_size"], |
| 124 | + image_shape=template_shape, |
| 125 | + rfft=True, |
| 126 | + fftshift=False, |
| 127 | + even_zernike_coeffs=kwargs["even_zernikes"], |
| 128 | + odd_zernike_coeffs=kwargs["odd_zernikes"], |
| 129 | + transform_matrix=mag_matrix, |
| 130 | + ) |
100 | 131 | # calc B-envelope and apply |
101 | 132 | b_envelope_tmp = b_envelope( |
102 | 133 | B=kwargs["ctf_B_factor"], |
@@ -139,22 +170,27 @@ def calculate_ctf_filter_stack( |
139 | 170 | Tensor of CTF filter values for the specified shape and optics group. Will have |
140 | 171 | shape (num_pixel_sizes, num_defocus_offsets, h, w // 2 + 1) |
141 | 172 | """ |
| 173 | + kwargs: dict[str, Any] = { |
| 174 | + "astigmatism_angle": optics_group.astigmatism_angle, |
| 175 | + "voltage": optics_group.voltage, |
| 176 | + "spherical_aberration": optics_group.spherical_aberration, |
| 177 | + "amplitude_contrast_ratio": optics_group.amplitude_contrast_ratio, |
| 178 | + "ctf_B_factor": optics_group.ctf_B_factor, |
| 179 | + "phase_shift": optics_group.phase_shift, |
| 180 | + "pixel_size": optics_group.pixel_size, |
| 181 | + "even_zernikes": optics_group.even_zernikes, |
| 182 | + "odd_zernikes": optics_group.odd_zernikes, |
| 183 | + "mag_matrix": optics_group.mag_matrix_tensor, |
| 184 | + } |
| 185 | + if optics_group.laser_params is not None: |
| 186 | + kwargs["laser_params"] = optics_group.laser_params |
142 | 187 | return calculate_ctf_filter_stack_full_args( |
143 | 188 | template_shape, |
144 | 189 | optics_group.defocus_u, |
145 | 190 | optics_group.defocus_v, |
146 | 191 | defocus_offsets, |
147 | 192 | pixel_size_offsets, |
148 | | - astigmatism_angle=optics_group.astigmatism_angle, |
149 | | - voltage=optics_group.voltage, |
150 | | - spherical_aberration=optics_group.spherical_aberration, |
151 | | - amplitude_contrast_ratio=optics_group.amplitude_contrast_ratio, |
152 | | - ctf_B_factor=optics_group.ctf_B_factor, |
153 | | - phase_shift=optics_group.phase_shift, |
154 | | - pixel_size=optics_group.pixel_size, |
155 | | - even_zernikes=optics_group.even_zernikes, |
156 | | - odd_zernikes=optics_group.odd_zernikes, |
157 | | - mag_matrix=optics_group.mag_matrix_tensor, |
| 193 | + **kwargs, |
158 | 194 | ) |
159 | 195 |
|
160 | 196 |
|
|
0 commit comments