diff --git a/src/torch_ctf/ctf_lpp.py b/src/torch_ctf/ctf_lpp.py index 441632a..05eb370 100644 --- a/src/torch_ctf/ctf_lpp.py +++ b/src/torch_ctf/ctf_lpp.py @@ -454,6 +454,7 @@ def calc_LPP_ctf_2D( laser_trans_offset_angstrom: float, laser_polarization_angle_deg: float, peak_phase_deg: float, + dual_laser: bool = False, beam_tilt_mrad: torch.Tensor | None = None, even_zernike_coeffs: dict | None = None, odd_zernike_coeffs: dict | None = None, @@ -508,6 +509,10 @@ def calc_LPP_ctf_2D( Polarization angle of the laser in degrees. peak_phase_deg : float Desired peak phase in degrees. + dual_laser : bool, optional + If True, add a second laser with the same parameters but rotated 90° in the + xy plane (perpendicular to the first). The two phase contributions are summed. + Default is False. beam_tilt_mrad : torch.Tensor | None Beam tilt in milliradians. [bx, by] in mrad even_zernike_coeffs : dict | None @@ -580,19 +585,48 @@ def calc_LPP_ctf_2D( ) # Calculate laser phase using the dedicated function - laser_phase_radians = calc_LPP_phase( - fft_freq_grid=fft_freq_grid, - NA=NA, - laser_wavelength_angstrom=laser_wavelength_angstrom, - focal_length_angstrom=focal_length_angstrom, - laser_xy_angle_deg=laser_xy_angle_deg, - laser_xz_angle_deg=laser_xz_angle_deg, - laser_long_offset_angstrom=laser_long_offset_angstrom, - laser_trans_offset_angstrom=laser_trans_offset_angstrom, - laser_polarization_angle_deg=laser_polarization_angle_deg, - peak_phase_deg=peak_phase_deg, - voltage=voltage, - ) + if dual_laser: + phase1 = calc_LPP_phase( + fft_freq_grid=fft_freq_grid, + NA=NA, + laser_wavelength_angstrom=laser_wavelength_angstrom, + focal_length_angstrom=focal_length_angstrom, + laser_xy_angle_deg=laser_xy_angle_deg, + laser_xz_angle_deg=laser_xz_angle_deg, + laser_long_offset_angstrom=laser_long_offset_angstrom, + laser_trans_offset_angstrom=laser_trans_offset_angstrom, + laser_polarization_angle_deg=laser_polarization_angle_deg, + peak_phase_deg=peak_phase_deg, + voltage=voltage, + ) + phase2 = calc_LPP_phase( + fft_freq_grid=fft_freq_grid, + NA=NA, + laser_wavelength_angstrom=laser_wavelength_angstrom, + focal_length_angstrom=focal_length_angstrom, + laser_xy_angle_deg=laser_xy_angle_deg + 90, + laser_xz_angle_deg=laser_xz_angle_deg, + laser_long_offset_angstrom=laser_long_offset_angstrom, + laser_trans_offset_angstrom=laser_trans_offset_angstrom, + laser_polarization_angle_deg=laser_polarization_angle_deg, + peak_phase_deg=peak_phase_deg, + voltage=voltage, + ) + laser_phase_radians = phase1 + phase2 + else: + laser_phase_radians = calc_LPP_phase( + fft_freq_grid=fft_freq_grid, + NA=NA, + laser_wavelength_angstrom=laser_wavelength_angstrom, + focal_length_angstrom=focal_length_angstrom, + laser_xy_angle_deg=laser_xy_angle_deg, + laser_xz_angle_deg=laser_xz_angle_deg, + laser_long_offset_angstrom=laser_long_offset_angstrom, + laser_trans_offset_angstrom=laser_trans_offset_angstrom, + laser_polarization_angle_deg=laser_polarization_angle_deg, + peak_phase_deg=peak_phase_deg, + voltage=voltage, + ) # Convert laser phase from radians to degrees for compatibility laser_phase_degrees = torch.rad2deg(laser_phase_radians) diff --git a/tests/test_torch_ctf.py b/tests/test_torch_ctf.py index fbc2f3f..5219799 100644 --- a/tests/test_torch_ctf.py +++ b/tests/test_torch_ctf.py @@ -1587,6 +1587,42 @@ def test_calc_LPP_ctf_2D(): assert not torch.is_complex(result) +def test_calc_LPP_ctf_2D_dual_laser(): + """Test LPP CTF with dual perpendicular laser option.""" + common = { + "defocus": 1.5, + "astigmatism": 0, + "astigmatism_angle": 0, + "voltage": 300, + "spherical_aberration": 2.7, + "amplitude_contrast": 0.1, + "pixel_size": 8, + "image_shape": (10, 10), + "rfft": False, + "fftshift": False, + "NA": 0.1, + "laser_wavelength_angstrom": 5000.0, + "focal_length_angstrom": 1e6, + "laser_xy_angle_deg": 0.0, + "laser_xz_angle_deg": 0.0, + "laser_long_offset_angstrom": 0.0, + "laser_trans_offset_angstrom": 0.0, + "laser_polarization_angle_deg": 0.0, + "peak_phase_deg": 90.0, + } + result_single = calc_LPP_ctf_2D(**common, dual_laser=False) + result_dual = calc_LPP_ctf_2D(**common, dual_laser=True) + assert result_single.shape == (10, 10) + assert result_dual.shape == (10, 10) + assert torch.all(torch.isfinite(result_single)) + assert torch.all(torch.isfinite(result_dual)) + assert not torch.is_complex(result_single) + assert not torch.is_complex(result_dual) + assert not torch.allclose(result_single, result_dual), ( + "dual_laser=True should differ from dual_laser=False" + ) + + def test_calc_LPP_ctf_2D_with_zernikes(): """Test LPP CTF with Zernike coefficients.""" with pytest.warns(RuntimeWarning, match="Both beam tilt and Zernike"):