diff --git a/imap_processing/tests/ultra/unit/test_ultra_l1b_culling.py b/imap_processing/tests/ultra/unit/test_ultra_l1b_culling.py index 2fbba6c56..2277e489b 100644 --- a/imap_processing/tests/ultra/unit/test_ultra_l1b_culling.py +++ b/imap_processing/tests/ultra/unit/test_ultra_l1b_culling.py @@ -27,6 +27,7 @@ flag_low_voltage, flag_rates, flag_scattering, + flag_statistical_outliers, get_binned_energy_ranges, get_binned_spins_edges, get_de_rejection_mask, @@ -34,6 +35,7 @@ get_energy_histogram, get_energy_range_flags, get_n_sigma, + get_poisson_stats, get_pulses_per_spin, get_spin_data, get_valid_earth_angle_events, @@ -337,7 +339,6 @@ def test_flag_low_voltage(test_data): "leftdeflection_v": np.full(n_spins, 1.5), } ) - flagged = 65535 spins = np.arange(n_spins) spin_bin_size = 5 spin_period = np.full(n_spins, 15.0) @@ -353,14 +354,14 @@ def test_flag_low_voltage(test_data): # Check quality flag shape assert quality_flags.shape == (len(spin_tbin_edges) - 1,) # Check that every spin is flagged for low voltage - assert np.all(quality_flags == flagged) + assert np.all(quality_flags) # Set only the first spin to be below threshold mock_status_dataset["rightdeflection_v"].data[1:] += 5000 mock_status_dataset["leftdeflection_v"].data[1:] += 5000 quality_flags = flag_low_voltage(spin_tbin_edges, mock_status_dataset) # Check that only the first spin is flagged for low voltage - assert np.all(quality_flags[0] == flagged) + assert np.all(quality_flags[0]) # The rest should not be flagged assert np.all(quality_flags[1:] == 0) @@ -388,9 +389,7 @@ def test_flag_low_voltage_incomplete_bins(test_data): # check quality flag assert quality_flags.shape == (n_spins // spin_bin_size,) - # Check that every spin is flagged for low voltage - flagged = 65535 - assert np.all(quality_flags == flagged) + assert np.all(quality_flags) def test_expand_bin_flags_to_spins(caplog): @@ -479,9 +478,7 @@ def test_validate_voltage_cull(): xspin.spin_start_time.values, spin_bin_size, ) - lv_flags 
= flag_low_voltage( - spin_tbin_edges, status_ds, lv_threshold, low_voltage_flag=1 - ) + lv_flags = flag_low_voltage(spin_tbin_edges, status_ds, lv_threshold) assert np.array_equal(lv_flags, validation_low_voltage_qf) @@ -644,10 +641,10 @@ def test_flag_high_energy(): energy_range_edges = np.array([3, 5, 7, 18, 25]) # Example energy bin edges # Spin bin 1 (events 0-3) 4 events fall within the culling energy bin # - This is above all the energy thresholds except the second one (10), - # so it should be flagged for all energy ranges except flag #2 + # so it should be flagged for all energy ranges except flag the 2nd # Spin bin 2 (events 4-7) only 1 event falls within the culling energy bin # - This is above the lowest energy threshold (1) but below the rest, so - # it should only be flagged with the last energy range flag (#3) + # it should only be flagged at the last energy range # Spin bin 3 (events 8-11) No events fall within the culling energy bin, # so it should not be flagged for any energy range energy = np.array([17, 16, 12, 15, 5, 8, 4, 6, 4, 1, 22, 20]) @@ -666,26 +663,31 @@ def test_flag_high_energy(): spin_tbin_edges = np.arange( start=0, stop=len(energy) + 1, step=4 ) # create spin bins of 4 seconds - energy_range_flags = get_energy_range_flags(energy_range_edges) quality_flags = flag_high_energy( de_dataset, spin_tbin_edges, energy_range_edges, - energy_range_flags, + None, cull_thresholds, 90, ) # check shape - assert len(quality_flags) == len(spin_tbin_edges) - 1 + np.testing.assert_array_equal( + quality_flags.shape, (len(energy_range_edges) - 1, len(spin_tbin_edges) - 1) + ) # Assert that the first spin bin is flagged for high energy for all energy ranges # except the second one - assert quality_flags[0] == (2**0 | 2**2 | 2**3) + assert quality_flags[0, 0] + assert not quality_flags[1, 0] + assert quality_flags[2, 0] + assert quality_flags[3, 0] # Assert that the second spin bin is only flagged for high energy for the last - # energy range - assert 
quality_flags[1] == 2**3 - # Assert that the third spin bin is not flagged for any energy range - assert quality_flags[2] == 0 + # energy range + assert quality_flags[3, 1] + assert not np.any(quality_flags[0:3, 1]) + # Assert that the third spin bin is not flagged for any energy range + assert not np.any(quality_flags[:, 2]) @pytest.mark.external_test_data @@ -720,23 +722,122 @@ def test_validate_high_energy_cull(): intervals, _, _ = build_energy_bins() # Get the energy ranges energy_ranges = get_binned_energy_ranges(intervals) - flags = get_energy_range_flags(energy_ranges) e_flags = flag_high_energy( - de_ds, spin_tbin_edges, energy_ranges, flags, mock_thresholds - ) - # The ULTRA IT flags are shaped n_energy_ranges, spin_bin while the SDC - # is only spin_bin but different flags are set for different energy ranges, so we - # need to check each energy separately - # We also need to invert the expected flags since the ULTRA IT mask is True - # for good spins (counts below threshold) while the SDC quality flags are set - # for bad spins (counts exceed threshold). - for i in range(expected_qf.shape[0]): - np.testing.assert_array_equal( - (e_flags & 2**i) > 0, - ~expected_qf[i, :].astype(bool), - err_msg=f"High energy flag mismatch for energy range {i} with edges" - f" {energy_ranges[i]}-{energy_ranges[i + 1]}", - ) + de_ds, spin_tbin_edges, energy_ranges, None, mock_thresholds + ) + np.testing.assert_array_equal(e_flags, ~expected_qf.astype(bool)) + + +def test_flag_statistical_outliers(): + """Tests flag_statistical_outliers function.""" + energy_range_edges = np.array([3, 5, 7, 18, 25]) # Example energy bin edges + n_spin_bins = 12 + spin_step = 7 + energy = np.full(spin_step * n_spin_bins, 0) + # Make sure there are at least 3 other bins with counts in each energy bin so that + # the statistics can be calculated.
+ energy[::spin_step] = 3 + energy[1::spin_step] = 5 + energy[2::spin_step] = 7 + energy[3::spin_step] = 18 + # Make the last spin bin have higher counts. It should get flagged as an outlier. + energy[-spin_step:] = 23 + + de_dataset = xr.Dataset( + { + "de_event_met": ("epoch", np.arange(len(energy))), + "energy_spacecraft": ("epoch", energy), + "quality_outliers": ("epoch", np.full(len(energy), 0)), + "quality_scattering": ("epoch", np.full(len(energy), 0)), + "ebin": ("epoch", np.full(len(energy), 10)), + } + ) + spin_tbin_edges = np.arange( + start=0, stop=len(energy) + 1, step=spin_step + ) # create spin bins of 7 seconds + quality_flags, convergence, iterations, std_diff = flag_statistical_outliers( + de_dataset, + spin_tbin_edges, + energy_range_edges, + np.zeros((len(energy_range_edges) - 1, len(spin_tbin_edges) - 1), dtype=bool), + combine_flags_across_energy_bins=True, + ) + + # check shape + np.testing.assert_array_equal( + quality_flags.shape, (len(energy_range_edges) - 1, len(spin_tbin_edges) - 1) + ) + # check that none of the flags are set except for the last spin bin + # since combine_flags_across_energy_bins is True, the entire last spin bin should + # be flagged even though only one energy bin had high counts + expected_flags = np.zeros( + (len(energy_range_edges) - 1, len(spin_tbin_edges) - 1), dtype=bool + ) + expected_flags[:, -1] = True + np.testing.assert_array_equal(quality_flags, expected_flags) + # all energy bins should have converged + # The first 2 didn't have enough events to calculate statistics, but they should + # still be marked as converged + assert np.all(convergence) + # All energy bins should have iterated 1 time except the last one which should have + # iterated twice. 
+ assert np.all(iterations[:-1] == 1) + assert iterations[-1] == 2 + # Check that all std_diff values were set (not zero) except the last one + assert np.all(std_diff[:-1] != 0) + assert std_diff[-1] == 0 + + +def test_flag_statistical_outliers_invalid_events(): + """Tests flag_statistical_outliers function when there are no valid events.""" + energy_range_edges = np.array([3, 5, 7, 18, 25]) # Example energy bin edges + energy = np.arange(25) + de_dataset = xr.Dataset( + { + "de_event_met": ("epoch", np.arange(len(energy))), + "energy_spacecraft": ("epoch", energy), + "quality_outliers": ("epoch", np.full(len(energy), 0)), + "quality_scattering": ("epoch", np.full(len(energy), 0)), + "ebin": ("epoch", np.full(len(energy), 10)), + } + ) + spin_tbin_edges = np.arange( + start=0, stop=len(energy) + 1, step=5 + ) # create spin bins of 5 seconds + mask = np.ones((len(energy_range_edges) - 1, len(spin_tbin_edges) - 1), dtype=bool) + quality_flags, convergence, iterations, std_diff = flag_statistical_outliers( + de_dataset, + spin_tbin_edges, + energy_range_edges, + mask, + ) + # check that all flags are set because there are no valid events in any energy bin + # so it fails the stat outlier check by default. + np.testing.assert_array_equal( + quality_flags, np.ones_like(quality_flags, dtype=bool) + ) + # check that all energy bins are marked as converged (no valid events is not a + # failure case for convergence since we just can't calculate statistics).
+ assert np.all(convergence) + # check that there were no iterations + assert np.sum(iterations) == 0 + # Check that std_diff is all invalid (-1) + assert np.all(std_diff == -1) + + +def test_get_poisson_stats(): + """Tests get_poisson_stats function.""" + counts = np.full(20, 0) + counts[-1] = 100 # Make the last bin have high counts + std_diff, outlier_mask = get_poisson_stats(counts) + # std_diff should be std(counts) / sqrt(mean(counts)) - 1 + assert ( + std_diff == np.std(counts) / np.sqrt(5) - 1 + ) # mean(counts) = 100 / 20 = 5 + assert ( + np.sum(outlier_mask) == 1 + ) # Only the last bin should be flagged as an outlier + assert outlier_mask[-1] def test_get_energy_range_flags(): diff --git a/imap_processing/ultra/constants.py b/imap_processing/ultra/constants.py index ce759f859..bb6c52b79 100644 --- a/imap_processing/ultra/constants.py +++ b/imap_processing/ultra/constants.py @@ -204,3 +204,7 @@ class UltraConstants: ) # Use the channel defined below to determine which spins are contaminated HIGH_ENERGY_CULL_CHANNEL = 4 + # Maximum number of iterations to perform for statistical outlier culling. + STAT_CULLING_N_ITER = 5 + # Convergence threshold on std/sqrt(mean) - 1 for statistical outlier culling.
+ STAT_CULLING_STD_THRESHOLD = 0.05 diff --git a/imap_processing/ultra/l1b/extendedspin.py b/imap_processing/ultra/l1b/extendedspin.py index a68f0fb6a..41eed8642 100644 --- a/imap_processing/ultra/l1b/extendedspin.py +++ b/imap_processing/ultra/l1b/extendedspin.py @@ -4,7 +4,6 @@ import xarray as xr from numpy.typing import NDArray -from imap_processing.quality_flags import ImapRatesUltraFlags from imap_processing.ultra.constants import UltraConstants from imap_processing.ultra.l1b.ultra_l1b_culling import ( count_rejected_events_per_spin, @@ -15,6 +14,7 @@ flag_imap_instruments, flag_low_voltage, flag_rates, + flag_statistical_outliers, get_binned_energy_ranges, get_binned_spins_edges, get_energy_histogram, @@ -75,27 +75,32 @@ def calculate_extendedspin( spin, spin_period, spin_starttime, spin_bin_size ) voltage_qf = flag_low_voltage(spin_tbin_edges, status_dataset) - voltage_qf = expand_bin_flags_to_spins(len(spin), voltage_qf, spin_bin_size) # Get energy bins used at l1c intervals, _, _ = build_energy_bins() # Get the energy ranges energy_ranges = get_binned_energy_ranges(intervals) energy_bin_flags = get_energy_range_flags(energy_ranges) - # Calculate the high energy quality flags using the de dataset with low voltage - # events removed. Use the same spin and energy bins that - # were used for low voltage flags to maintain consistency in the flags. - valid_voltage_spins = spin[np.where(voltage_qf == 0)] - valid_de_spins = np.isin(de_dataset["spin"].values, valid_voltage_spins) - de_dataset_filtered = de_dataset.isel(epoch=valid_de_spins) + # Calculate the high energy quality flags energy_thresholds = UltraConstants.HIGH_ENERGY_CULL_THRESHOLDS high_energy_qf = flag_high_energy( - de_dataset_filtered, + de_dataset, spin_tbin_edges, energy_ranges, - energy_bin_flags, + voltage_qf, energy_thresholds, instrument_id, ) + # Combine high energy and voltage flags to use for statistical outlier flagging. 
+ mask = ( + voltage_qf[np.newaxis, :] | high_energy_qf + ) # Shape (n_energy_bins, n_spin_bins) + stat_outliers_qf, _, _, _ = flag_statistical_outliers( + de_dataset, + spin_tbin_edges, + energy_ranges, + mask, + instrument_id, + ) # Get the number of pulses per spin. pulses = get_pulses_per_spin(aux_dataset, rates_dataset) @@ -132,8 +137,25 @@ def calculate_extendedspin( stop_per_spin[valid] = pulses.stop_per_spin[idx[valid]] coin_per_spin[valid] = pulses.coin_per_spin[idx[valid]] + # high energy and statistical outlier flags are energy dependent boolean arrays + # with shape (n_energy_bins, n_spin_bins). We want to collapse the energy dimension + # using a bitwise OR to get a single bitmask per spin bin. + high_energy_qf = np.bitwise_or.reduce( + high_energy_qf * energy_bin_flags[:, np.newaxis], axis=0 + ) + stat_outliers_qf = np.bitwise_or.reduce( + stat_outliers_qf * energy_bin_flags[:, np.newaxis], axis=0 + ) + # Low voltage flag is shape (n_spin_bins,) but we want to convert from a boolean + # to a bitwise flag to be consistent with the other flags, where each spin that + # is flagged will have the bitflag of all the energy flags combined. + voltage_qf = voltage_qf * np.bitwise_or.reduce(energy_bin_flags) # Expand binned quality flags to individual spins. high_energy_qf = expand_bin_flags_to_spins(len(spin), high_energy_qf, spin_bin_size) + voltage_qf = expand_bin_flags_to_spins(len(spin), voltage_qf, spin_bin_size) + stat_outliers_qf = expand_bin_flags_to_spins( + len(spin), stat_outliers_qf, spin_bin_size + ) # account for rates spins which are not in the direct event spins extendedspin_dict["start_pulses_per_spin"] = start_per_spin extendedspin_dict["stop_pulses_per_spin"] = stop_per_spin @@ -146,9 +168,7 @@ def calculate_extendedspin( extendedspin_dict["quality_low_voltage"] = voltage_qf # shape (nspin,) # TODO calculate flags for high energy (SEPS) and statistics culling # Initialize these flags to NONE for now.
- extendedspin_dict["quality_statistics"] = np.full_like( - voltage_qf, ImapRatesUltraFlags.NONE.value, np.uint16 - ) # shape (nspin,) + extendedspin_dict["quality_statistics"] = stat_outliers_qf # shape (nspin,) extendedspin_dict["quality_high_energy"] = high_energy_qf # shape (nspin,) # Add an array of flags for each energy bin. Shape: (n_energy_bins) extendedspin_dict["energy_range_flags"] = energy_bin_flags diff --git a/imap_processing/ultra/l1b/ultra_l1b_culling.py b/imap_processing/ultra/l1b/ultra_l1b_culling.py index 946364923..a803de862 100644 --- a/imap_processing/ultra/l1b/ultra_l1b_culling.py +++ b/imap_processing/ultra/l1b/ultra_l1b_culling.py @@ -543,7 +543,7 @@ def get_energy_and_spin_dependent_rejection_mask( goodtimes_dataset : xr.Dataset Dataset containing valid spins and energy bin flags. energy : np.ndarray - The particle energy. + The particle energy at each direct event. spin_number : np.ndarray Spin number at each direct event. @@ -629,7 +629,6 @@ def flag_low_voltage( spin_tbin_edges: NDArray, status_dataset: xr.Dataset, voltage_threshold: float = UltraConstants.LOW_VOLTAGE_CULL_THRESHOLD, - low_voltage_flag: int = 65535, # default is max uint16 ) -> NDArray: """ Flag low voltage events. @@ -642,19 +641,15 @@ def flag_low_voltage( Status dataset containing voltage information. voltage_threshold : float Voltage threshold below which to flag low voltage events. - low_voltage_flag : int - The flag value to set for low voltage events. Returns ------- quality_flags : NDArray - Quality flags. + Boolean quality flags shaped (n_spin_bins,). 
""" spin_bin_size = len(spin_tbin_edges) - 1 # initialize all spins to have no low voltage flag - quality_flags = np.full( - spin_bin_size, ImapRatesUltraFlags.NONE.value, dtype=np.uint16 - ) + quality_flags = np.zeros(spin_bin_size, dtype=bool) # Get the min voltage across both deflection plate at each epoch min_voltage = np.minimum( status_dataset["rightdeflection_v"].data, @@ -675,7 +670,9 @@ def flag_low_voltage( valid_bin_inds = (lv_spin_inds >= 0) & (lv_spin_inds < spin_bin_size) lv_spin_inds = lv_spin_inds[valid_bin_inds] # For each low voltage ind, flag the corresponding flag - quality_flags[lv_spin_inds] = low_voltage_flag + quality_flags[lv_spin_inds] = True + + # TODO add log summary. return quality_flags @@ -684,7 +681,7 @@ def flag_high_energy( de_dataset: xr.Dataset, spin_tbin_edges: NDArray, energy_ranges: NDArray, - energy_range_flags: np.ndarray, + mask: NDArray = None, energy_thresholds: np.ndarray = UltraConstants.HIGH_ENERGY_CULL_THRESHOLDS, sensor_id: int = 90, ) -> NDArray: @@ -699,8 +696,10 @@ def flag_high_energy( Edges of the spin time bins. energy_ranges : numpy.ndarray Array of energy range edges. - energy_range_flags : numpy.ndarray - Array of quality flag values corresponding to each energy range. + mask : numpy.ndarray, optional + Mask indicating which events to consider for high energy flagging + (e.g., after low voltage culling). True indicates the spin bins that should + NOT be considered for high energy flagging. energy_thresholds : numpy.ndarray Array of count thresholds for flagging high energy events corresponding to each energy range. @@ -710,48 +709,254 @@ def flag_high_energy( Returns ------- quality_flags : numpy.ndarray - Quality flags. + Boolean quality flags shaped (n_energy_bins, n_spin_bins). 
""" + # expand energy thresholds to have shape (n_energy_bins, 1) for comparison with + # the counts per spin + energy_thresholds = energy_thresholds[:, np.newaxis] # Shape (n_energy_bins, 1) cull_channel = UltraConstants.HIGH_ENERGY_CULL_CHANNEL - valid_events_per_energy = get_valid_events_per_energy_range( - de_dataset, energy_ranges, UltraConstants.EARTH_ANGLE_45_THRESHOLD, sensor_id - ) - # check to make sure the number of energy ranges matches the number of energy range - # flags - num_e_ranges = valid_events_per_energy.shape[0] - if num_e_ranges != len(energy_range_flags) or num_e_ranges != len( - energy_thresholds - ): - raise ValueError( - f"Number of energy ranges ({num_e_ranges}) does not match number of energy" - f" range flags ({len(energy_range_flags)}) or expected number of " - f"energy range thresholds ({len(energy_thresholds)})." - ) - if cull_channel >= num_e_ranges: + n_energy_bins = len(energy_thresholds) + if cull_channel >= n_energy_bins: raise ValueError( f"HIGH_ENERGY_CULL_CHANNEL ({cull_channel}) is out of bounds" - f" for {num_e_ranges} energy ranges." + f" for {n_energy_bins} energy ranges." ) # Initialize all spin bins to have no high energy flag spin_bin_size = len(spin_tbin_edges) - 1 - quality_flags = np.full( - spin_bin_size, ImapRatesUltraFlags.NONE.value, dtype=np.uint16 - ) + quality_flags = np.zeros((n_energy_bins, spin_bin_size), dtype=bool) # Get valid events and counts at each spin bin for the # designated culling channel. 
- cull_channel_events = valid_events_per_energy[cull_channel] - # get each valid event count per spin bin for the culling channel - cull_channel_counts = np.histogram( - de_dataset["de_event_met"].values[cull_channel_events], spin_tbin_edges - )[0] - # loop through each energy range - for flag, e_threshold in zip(energy_range_flags, energy_thresholds, strict=False): - quality_flags[cull_channel_counts >= e_threshold] |= flag + de_counts = get_valid_de_count_summary( + de_dataset, energy_ranges, spin_tbin_edges, sensor_id + ) + cull_channel_counts = de_counts[cull_channel] + # flag spins where the counts in the cull channel exceed the threshold for that + # energy range + flagged = ( + cull_channel_counts[np.newaxis, :] >= energy_thresholds + ) # (n_energy_bins, n_spin_bins) + + if mask is not None: + quality_flags[:, ~mask] = flagged[:, ~mask] + else: + quality_flags = flagged + # TODO add log summary. E.g Tim's hi goodtimes code return quality_flags +def flag_statistical_outliers( + de_dataset: xr.Dataset, + spin_tbin_edges: NDArray, + energy_ranges: NDArray, + mask: NDArray, + sensor_id: int = 90, + n_iterations: int = UltraConstants.STAT_CULLING_N_ITER, + std_threshold: float = UltraConstants.STAT_CULLING_STD_THRESHOLD, + combine_flags_across_energy_bins: bool = True, +) -> tuple[NDArray, NDArray, NDArray, NDArray]: + """ + Flag statistical outlier events based on count rates per spin. + + After low voltage and high energy spins have been flagged, there still appears to + be some time dependency in the signal. This algorithm identifies those outliers. + + Iterative algorithm to identify areas consistent with Poisson statistics + For each energy range: + 1. Flag where there are less than 3 bins with counts + 2. Calculate the mean (μ) and standard deviation (σ) of the counts in each bin. + 3. Find bins where the counts, c, yield |(c-μ)/σ|>3,  cull these bins + 4. Calculate ε=σ/√μ-1 + 5. If ε is less than a threshold value (0.05 for now) stop iterating + 6. 
If number of iterations exceeds threshold (5 for now), stop iterating + 7. Return to step 1 + + Parameters + ---------- + de_dataset : xr.Dataset + Direct event dataset. + spin_tbin_edges : numpy.ndarray + Edges of the spin time bins. + energy_ranges : numpy.ndarray + Array of energy range edges. + mask : numpy.ndarray + Mask indicating which events to consider for statistical outlier flagging. + This should be a 2d boolean array of shape (n_energy_bins, n_spin_bins) where + True indicates the spin bins that have been flagged in previous steps (e.g., + after low voltage and high energy culling) and should be excluded from the + outlier flagging process. + sensor_id : int + Sensor ID (e.g., 45 or 90). + n_iterations : int + Maximum number of iterations to perform for outlier flagging. + std_threshold : float + Threshold for standard deviation difference from Poisson stats to determine + convergence. + combine_flags_across_energy_bins : bool + Whether to link energy channels such that if a spin bin is flagged in any energy + channel, it is flagged in all energy channels. + + Returns + ------- + quality_stats : numpy.ndarray + Quality flags for statistical outliers, shaped (n_energy_bins, n_spin_bins). + convergence : numpy.ndarray + Boolean array of shape (n_energy_bins,) indicating whether the outlier flagging + converged for each energy bin. + iterations : numpy.ndarray + Array of shape (n_energy_bins,) indicating how many iterations were performed + for each energy bin. + std_diff : numpy.ndarray + Array of shape (n_energy_bins,) containing the final standard deviation + difference from Poisson stats for each energy bin. 
+ """ + # Initialize all spin bins to have no outlier flag + spin_bin_size = len(spin_tbin_edges) - 1 + n_energy_bins = len(energy_ranges) - 1 + # make a copy of the mask to avoid modifying the original mask passed in + iter_mask = mask.copy() + quality_stats = np.zeros((n_energy_bins, spin_bin_size), dtype=bool) + # Initialize convergence array to keep track of poisson stats + convergence = np.full(n_energy_bins, False) + # Keep track of how many iterations we have done of flagging outliers and + # recalculating stats per energy bin + iterations = np.zeros(n_energy_bins) + # keep track of the standard deviation difference from poisson stats per energy bin + std_diff = np.zeros(n_energy_bins, dtype=float) + count_summary = get_valid_de_count_summary( + de_dataset, energy_ranges, spin_tbin_edges, sensor_id + ) # shape (n_energy_bins, n_spin_bins) + for e_idx in np.arange(n_energy_bins): + for it in range(n_iterations): + # only consider bins that are currently unflagged for this energy bin + counts = count_summary[e_idx, ~iter_mask[e_idx]] + # Step 1. check if any energy bins have less than 3 spin bins with counts. + # If so, flag all spins for that energy bin and skip to the next iteration + if np.sum(counts > 0) < 3: + quality_stats[e_idx] = True + convergence[e_idx] = True + std_diff[e_idx] = -1 + break + # Step 2. Check how close the data is to poisson stats + std_ratio, outlier_mask = get_poisson_stats(counts) + std_diff[e_idx] = std_ratio + # Step 3. Flag bins where the count is more than 3 standard deviations from + # the mean. 
+ outlier_inds = np.where(~iter_mask[e_idx])[0][outlier_mask] + # Set the quality flag to True for the outlier inds + quality_stats[e_idx, outlier_inds] = True + # Also update the iter_mask to exclude the outlier bins for the next + # iteration + iter_mask[e_idx, outlier_inds] = True + iterations[e_idx] = it + 1 + # Check for convergence: if the standard deviation difference from + # poisson stats is below the threshold, then we can stop iterating for this + # energy bin + if std_ratio < std_threshold: + convergence[e_idx] = True + break + + if combine_flags_across_energy_bins: + # If a spin bin is flagged in any energy channel flag it in all energy channels + # Use np.any to check if a spin bin is flagged in any energy channel, + # then flag it in all energy channels + combined_mask = np.any(quality_stats, axis=0) # (n_spin_bins,) + quality_stats[:] = combined_mask # Broadcast to all energy bins + + # Recalculate convergence with the combined mask. + for e_idx in range(n_energy_bins): + if not convergence[e_idx]: + # Select counts that have not been flagged in any channel. + counts = count_summary[e_idx, ~combined_mask] + std_ratio, _ = get_poisson_stats(counts) + std_diff[e_idx] = std_ratio + if std_ratio < std_threshold: + convergence[e_idx] = True + + num_culled: int = np.sum(quality_stats) + logger.debug( + f"Statistical culling removed {num_culled} spin bins across {n_energy_bins}" + f" energy channels. Convergence: {convergence} after " + f"{iterations} iterations." + ) + + return quality_stats, convergence, iterations, std_diff + + +def get_poisson_stats(counts: NDArray) -> tuple[float, NDArray]: + """ + Calculate Poisson statistics for a given array of counts. + + For a perfect Poisson distribution, the standard deviation should equal + the square root of the mean. The std_ratio measures how far the observed + distribution deviates from this. + + Outliers are identified as bins where the counts deviate more than 3 + standard deviations from the mean. 
+ + Parameters + ---------- + counts : numpy.ndarray + Array of counts per spin bin for a given energy range. + + Returns + ------- + std_ratio : float + Ratio of the observed standard deviation to the expected Poisson + standard deviation (square root of the mean), minus 1. + sub_mask : numpy.ndarray + Boolean array of the same length as counts. True where a bin is + a statistical outlier (more than 3 sigma from the mean). + """ + std = np.std(counts) + if std == 0: + # If std is 0, then all counts are the same. In this case, we can consider + # there to be no outliers and the distribution to perfectly match Poisson + return 0, np.zeros_like(counts, dtype=bool) + std_ratio = std / np.sqrt(np.mean(counts)) - 1 + sub_mask = np.abs((counts - np.mean(counts)) / std) > 3 + return std_ratio, sub_mask + + +def get_valid_de_count_summary( + de_dataset: xr.Dataset, + energy_ranges: NDArray, + spin_tbin_edges: NDArray, + sensor_id: int = 90, +) -> NDArray: + """ + Get a summary of valid counts per energy range and spin bin. + + Parameters + ---------- + de_dataset : xr.Dataset + Direct event dataset. + energy_ranges : numpy.ndarray + Array of energy range edges. + spin_tbin_edges : numpy.ndarray + Array of spin time bin edges. + sensor_id : int + Sensor ID (e.g., 45 or 90). + + Returns + ------- + counts : numpy.ndarray + A 2D array of counts per energy range and spin bin for valid events. + """ + valid_events = get_valid_events_per_energy_range( + de_dataset, energy_ranges, UltraConstants.EARTH_ANGLE_45_THRESHOLD, sensor_id + ) + counts = np.zeros((len(energy_ranges) - 1, len(spin_tbin_edges) - 1), dtype=float) + + for i in range(len(energy_ranges) - 1): + counts[i, :], _ = np.histogram( + de_dataset["de_event_met"].values[valid_events[i, :]], bins=spin_tbin_edges + ) + + return counts + + def get_valid_events_per_energy_range( de_dataset: xr.Dataset, energy_ranges: NDArray, earth_ang_45: float, sensor_id: int ) -> NDArray: