Skip to content

Commit a081497

Browse files
authored
* ENH: Add flux corrector to Lo processing
1 parent ce8cb30 commit a081497

2 files changed

Lines changed: 274 additions & 6 deletions

File tree

imap_processing/lo/l2/lo_l2.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from imap_processing.cdf.imap_cdf_manager import ImapCdfAttributes
1111
from imap_processing.ena_maps import ena_maps
1212
from imap_processing.ena_maps.ena_maps import AbstractSkyMap, RectangularSkyMap
13+
from imap_processing.ena_maps.utils.corrections import PowerLawFluxCorrector
1314
from imap_processing.ena_maps.utils.naming import MapDescriptor
1415
from imap_processing.lo import lo_ancillary
1516
from imap_processing.spice.time import et_to_datetime64, ttj2000ns_to_et
@@ -77,15 +78,23 @@ def lo_l2(
7778
logger.info("Step 4: Calculating rates and intensities")
7879

7980
# Determine if corrections are needed and prepare oxygen data if required
80-
sputtering_correction, bootstrap_correction, o_map_dataset = _prepare_corrections(
81+
(
82+
sputtering_correction,
83+
bootstrap_correction,
84+
flux_correction,
85+
o_map_dataset,
86+
flux_factors,
87+
) = _prepare_corrections(
8188
map_descriptor, descriptor, sci_dependencies, anc_dependencies
8289
)
8390

8491
dataset = calculate_all_rates_and_intensities(
8592
dataset,
8693
sputtering_correction=sputtering_correction,
8794
bootstrap_correction=bootstrap_correction,
95+
flux_correction=flux_correction,
8896
o_map_dataset=o_map_dataset,
97+
flux_factors=flux_factors,
8998
)
9099

91100
logger.info("Step 5: Finalizing dataset with attributes")
@@ -100,7 +109,7 @@ def _prepare_corrections(
100109
descriptor: str,
101110
sci_dependencies: dict,
102111
anc_dependencies: list,
103-
) -> tuple[bool, bool, xr.Dataset | None]:
112+
) -> tuple[bool, bool, bool, xr.Dataset | None, Path | None]:
104113
"""
105114
Determine what corrections are needed and prepare oxygen dataset if required.
106115
@@ -130,7 +139,9 @@ def _prepare_corrections(
130139
# Default values - no corrections needed
131140
sputtering_correction = False
132141
bootstrap_correction = False
142+
flux_correction = False
133143
o_map_dataset = None
144+
flux_factors: None | Path = None
134145

135146
# Sputtering and bootstrap corrections are only applied to hydrogen ENA data
136147
# Guard against recursion: don't process oxygen for oxygen maps
@@ -145,7 +156,24 @@ def _prepare_corrections(
145156
sputtering_correction = True
146157
bootstrap_correction = True
147158

148-
return sputtering_correction, bootstrap_correction, o_map_dataset
159+
if "raw" not in map_descriptor.principal_data:
160+
flux_correction = True
161+
try:
162+
flux_factors = next(
163+
x for x in anc_dependencies if "esa-eta-fit-factors" in str(x)
164+
)
165+
except StopIteration:
166+
raise ValueError(
167+
"No flux correction factor file found in ancillary dependencies"
168+
) from None
169+
170+
return (
171+
sputtering_correction,
172+
bootstrap_correction,
173+
flux_correction,
174+
o_map_dataset,
175+
flux_factors,
176+
)
149177

150178

151179
# =============================================================================
@@ -664,7 +692,9 @@ def calculate_all_rates_and_intensities(
664692
dataset: xr.Dataset,
665693
sputtering_correction: bool = False,
666694
bootstrap_correction: bool = False,
695+
flux_correction: bool = False,
667696
o_map_dataset: xr.Dataset | None = None,
697+
flux_factors: Path | None = None,
668698
) -> xr.Dataset:
669699
"""
670700
Calculate rates and intensities with proper error propagation.
@@ -679,8 +709,13 @@ def calculate_all_rates_and_intensities(
679709
bootstrap_correction : bool, optional
680710
Whether to apply bootstrap corrections to intensities.
681711
Default is False.
712+
flux_correction : bool, optional
713+
Whether to apply flux corrections to intensities.
714+
Default is False.
682715
o_map_dataset : xr.Dataset, optional
683716
Dataset specifically for oxygen, needed for sputtering corrections.
717+
flux_factors : Path, optional
718+
Path to flux factor file for flux corrections.
684719
685720
Returns
686721
-------
@@ -705,7 +740,13 @@ def calculate_all_rates_and_intensities(
705740
if bootstrap_correction:
706741
dataset = calculate_bootstrap_corrections(dataset)
707742

708-
# Step 6: Clean up intermediate variables
743+
# Optional Step 6: Calculate flux corrections
744+
if flux_correction:
745+
if flux_factors is None:
746+
raise ValueError("Flux factors file must be provided for flux corrections")
747+
dataset = calculate_flux_corrections(dataset, flux_factors)
748+
749+
# Step 7: Clean up intermediate variables
709750
dataset = cleanup_intermediate_variables(dataset)
710751

711752
return dataset
@@ -1084,6 +1125,56 @@ def calculate_bootstrap_corrections(dataset: xr.Dataset) -> xr.Dataset:
10841125
return dataset
10851126

10861127

1128+
def calculate_flux_corrections(dataset: xr.Dataset, flux_factors: Path) -> xr.Dataset:
1129+
"""
1130+
Calculate flux corrections for intensities.
1131+
1132+
Uses the shared ena maps ``PowerLawFluxCorrector`` class to do the
1133+
correction calculations.
1134+
1135+
Parameters
1136+
----------
1137+
dataset : xr.Dataset
1138+
Dataset with count rates, geometric factors, and center energies.
1139+
flux_factors : Path
1140+
Path to the eta flux factor file to use for corrections. Read in as
1141+
an ancillary file in the preprocessing step.
1142+
1143+
Returns
1144+
-------
1145+
xr.Dataset
1146+
Dataset with calculated flux-corrected intensities and their
1147+
uncertainties for the specified species.
1148+
"""
1149+
logger.info("Applying flux corrections")
1150+
1151+
# Flux correction
1152+
corrector = PowerLawFluxCorrector(flux_factors)
1153+
# FluxCorrector works on (energy, :) arrays, so we need to flatten the map
1154+
# spatial dimensions for the correction and then reshape back after.
1155+
input_shape = dataset["ena_intensity"].shape[1:] # Exclude epoch dimension
1156+
intensity = dataset["ena_intensity"].values[0].reshape(len(dataset["energy"]), -1)
1157+
stat_uncert = (
1158+
dataset["ena_intensity_stat_uncert"]
1159+
.values[0]
1160+
.reshape(len(dataset["energy"]), -1)
1161+
)
1162+
corrected_intensity, corrected_stat_unc = corrector.apply_flux_correction(
1163+
intensity,
1164+
stat_uncert,
1165+
dataset["energy"].data,
1166+
)
1167+
# Add the size 1 epoch dimension back into the corrected fluxes.
1168+
dataset["ena_intensity"].data = corrected_intensity.reshape(input_shape)[
1169+
np.newaxis, ...
1170+
]
1171+
dataset["ena_intensity_stat_uncert"].data = corrected_stat_unc.reshape(input_shape)[
1172+
np.newaxis, ...
1173+
]
1174+
1175+
return dataset
1176+
1177+
10871178
def cleanup_intermediate_variables(dataset: xr.Dataset) -> xr.Dataset:
10881179
"""
10891180
Remove intermediate variables that were only needed for calculations.

imap_processing/tests/lo/test_lo_l2.py

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Comprehensive test suite for IMAP-Lo L2 data processing."""
22

3+
from pathlib import Path
34
from unittest.mock import Mock, patch
45

56
import numpy as np
@@ -23,6 +24,7 @@
2324
calculate_backgrounds,
2425
calculate_bootstrap_corrections,
2526
calculate_efficiency_corrected_quantities,
27+
calculate_flux_corrections,
2628
calculate_intensities,
2729
calculate_rates,
2830
calculate_sputtering_corrections,
@@ -532,6 +534,56 @@ def sample_dataset_with_bootstrap_data():
532534
return dataset
533535

534536

537+
@pytest.fixture
538+
def lo_flux_factors_file():
539+
"""Path to the LO flux factors test file."""
540+
# Use the actual test data file from the ena_maps test data
541+
test_data_path = Path(__file__).parent.parent / "ena_maps" / "data"
542+
return test_data_path / "imap_lo_esa-eta-fit-factors_20240101_v001.csv"
543+
544+
545+
@pytest.fixture
546+
def sample_dataset_with_intensities():
547+
"""Create a dataset with intensities for flux correction testing."""
548+
n_energy = 7
549+
n_lon, n_lat = 6, 4 # Small for testing
550+
551+
# Create realistic energy values matching the flux factors file
552+
energy_values = np.array([16.35, 30.56, 56.4, 105, 199.8, 407.5, 795.3])
553+
554+
coords = {
555+
"epoch": [8.1794907049e17],
556+
"energy": energy_values,
557+
"longitude": np.linspace(0, 360, n_lon, endpoint=False),
558+
"latitude": np.linspace(-90, 90, n_lat),
559+
}
560+
561+
# Create intensity values with some spatial and energy structure
562+
intensity_values = np.ones((1, n_energy, n_lon, n_lat))
563+
for i in range(n_energy):
564+
# Power law: I = I0 * (E/E0)^(-2.0)
565+
intensity_values[0, i, :, :] = 1e6 * (energy_values[i] / 100.0) ** (-2.0)
566+
567+
# Add some spatial structure
568+
for j in range(n_lon):
569+
for k in range(n_lat):
570+
intensity_values[0, :, j, k] *= 1.0 + 0.1 * np.sin(j) * np.cos(k)
571+
572+
dataset = xr.Dataset(coords=coords)
573+
dataset["ena_intensity"] = (
574+
("epoch", "energy", "longitude", "latitude"),
575+
intensity_values,
576+
)
577+
578+
# Add statistical uncertainties (10% of intensity)
579+
dataset["ena_intensity_stat_uncert"] = (
580+
("epoch", "energy", "longitude", "latitude"),
581+
intensity_values * 0.1,
582+
)
583+
584+
return dataset
585+
586+
535587
# =============================================================================
536588
# UNIT TESTS FOR INDIVIDUAL FUNCTIONS
537589
# =============================================================================
@@ -1002,6 +1054,129 @@ def test_calculate_backgrounds_zero_exposure(self):
10021054
assert np.all(np.isinf(result["bg_rates_stat_uncert"].values))
10031055

10041056

1057+
class TestCalculateFluxCorrections:
1058+
"""Tests for the calculate_flux_corrections function."""
1059+
1060+
def test_calculate_flux_corrections_basic(
1061+
self, sample_dataset_with_intensities, lo_flux_factors_file
1062+
):
1063+
"""Test basic flux correction calculation."""
1064+
# Make a copy to avoid modifying the original fixture
1065+
original_dataset = sample_dataset_with_intensities.copy(deep=True)
1066+
1067+
# Run flux correction
1068+
result = calculate_flux_corrections(original_dataset, lo_flux_factors_file)
1069+
1070+
# Verify that the function returns a dataset
1071+
assert isinstance(result, xr.Dataset)
1072+
1073+
# Verify that intensity variables are present
1074+
assert "ena_intensity" in result.data_vars
1075+
assert "ena_intensity_stat_uncert" in result.data_vars
1076+
1077+
# Verify that data shape is preserved
1078+
original_shape = sample_dataset_with_intensities["ena_intensity"].shape
1079+
assert result["ena_intensity"].shape == original_shape
1080+
1081+
# Check that corrections were applied by comparing to the original fixture
1082+
# (not the potentially modified copy)
1083+
original_intensity = sample_dataset_with_intensities["ena_intensity"].values
1084+
corrected_intensity = result["ena_intensity"].values
1085+
1086+
# Check for meaningful differences
1087+
relative_diff = np.abs(
1088+
(corrected_intensity - original_intensity) / original_intensity
1089+
)
1090+
max_relative_diff = np.max(relative_diff)
1091+
# Should have at least 10% change somewhere
1092+
assert max_relative_diff > 0.1, (
1093+
f"Max relative difference was only {max_relative_diff}"
1094+
)
1095+
1096+
# Verify that uncertainties were also corrected
1097+
original_uncert = sample_dataset_with_intensities[
1098+
"ena_intensity_stat_uncert"
1099+
].values
1100+
corrected_uncert = result["ena_intensity_stat_uncert"].values
1101+
uncert_relative_diff = np.abs(
1102+
(corrected_uncert - original_uncert) / original_uncert
1103+
)
1104+
max_uncert_diff = np.max(uncert_relative_diff)
1105+
# Should have at least 10% change in uncertainties too
1106+
assert max_uncert_diff > 0.1, (
1107+
f"Max uncertainty relative difference was only {max_uncert_diff}"
1108+
)
1109+
1110+
def test_calculate_flux_corrections_preserves_other_vars(
1111+
self, sample_dataset_with_intensities, lo_flux_factors_file
1112+
):
1113+
"""Test that flux correction preserves other variables in the dataset."""
1114+
# Add an extra variable to the dataset
1115+
sample_dataset_with_intensities["extra_var"] = (("energy",), np.ones(7))
1116+
1117+
result = calculate_flux_corrections(
1118+
sample_dataset_with_intensities, lo_flux_factors_file
1119+
)
1120+
1121+
# Verify that other variables are preserved
1122+
assert "extra_var" in result.data_vars
1123+
np.testing.assert_array_equal(
1124+
result["extra_var"].values,
1125+
sample_dataset_with_intensities["extra_var"].values,
1126+
)
1127+
1128+
def test_calculate_flux_corrections_energy_dimension_handling(
1129+
self, lo_flux_factors_file
1130+
):
1131+
"""Test that flux correction properly handles energy dimension reshaping."""
1132+
# Create a dataset with different spatial dimensions
1133+
n_energy = 7
1134+
n_x, n_y = 12, 8 # Different spatial dimensions
1135+
1136+
energy_values = np.array([16.35, 30.56, 56.4, 105, 199.8, 407.5, 795.3])
1137+
1138+
coords = {
1139+
"epoch": [8.1794907049e17],
1140+
"energy": energy_values,
1141+
"x": np.arange(n_x),
1142+
"y": np.arange(n_y),
1143+
}
1144+
1145+
# Create intensity values with energy-dependent structure (power law)
1146+
intensity_values = np.ones((1, n_energy, n_x, n_y))
1147+
for i in range(n_energy):
1148+
intensity_values[0, i, :, :] = 1e6 * (energy_values[i] / 100.0) ** (-2.0)
1149+
uncert_values = intensity_values * 0.1
1150+
1151+
original_dataset = xr.Dataset(coords=coords)
1152+
original_dataset["ena_intensity"] = (
1153+
("epoch", "energy", "x", "y"),
1154+
intensity_values.copy(),
1155+
)
1156+
original_dataset["ena_intensity_stat_uncert"] = (
1157+
("epoch", "energy", "x", "y"),
1158+
uncert_values.copy(),
1159+
)
1160+
1161+
# Run flux correction on a copy
1162+
dataset_copy = original_dataset.copy(deep=True)
1163+
result = calculate_flux_corrections(dataset_copy, lo_flux_factors_file)
1164+
1165+
# Verify shape is preserved
1166+
assert result["ena_intensity"].shape == (1, n_energy, n_x, n_y)
1167+
assert result["ena_intensity_stat_uncert"].shape == (1, n_energy, n_x, n_y)
1168+
1169+
# Verify corrections were applied by checking for meaningful differences
1170+
original_values = original_dataset["ena_intensity"].values
1171+
corrected_values = result["ena_intensity"].values
1172+
relative_diff = np.abs((corrected_values - original_values) / original_values)
1173+
max_relative_diff = np.max(relative_diff)
1174+
# Should have at least 10% change somewhere (flux corrections are significant)
1175+
assert max_relative_diff > 0.1, (
1176+
f"Max relative difference was only {max_relative_diff}"
1177+
)
1178+
1179+
10051180
class TestCalculateSputteringCorrections:
10061181
"""Tests for the calculate_sputtering_corrections function."""
10071182

@@ -1970,11 +2145,13 @@ def test_calculate_all_rates_and_intensities_complete(self):
19702145
class TestIntegrationWithMocks:
19712146
"""Integration tests using mocked external dependencies."""
19722147

1973-
def test_lo_l2_integration_minimal(self, minimal_pset_for_species):
2148+
def test_lo_l2_integration_minimal(
2149+
self, minimal_pset_for_species, lo_flux_factors_file
2150+
):
19742151
"""Test the main lo_l2 function with minimal mocking."""
19752152
# Test with hydrogen data
19762153
sci_dependencies = {"imap_lo_l1c_pset": [minimal_pset_for_species]}
1977-
anc_dependencies = []
2154+
anc_dependencies = [lo_flux_factors_file] # Include flux factors file
19782155
descriptor = "l090-ena-h-sf-nsp-ram-hae-6deg-3mo"
19792156

19802157
# Mock the complex external dependencies to return simple results

0 commit comments

Comments
 (0)