diff --git a/docs/matlab_differences.rst b/docs/matlab_differences.rst index 1eace0d6..43228640 100644 --- a/docs/matlab_differences.rst +++ b/docs/matlab_differences.rst @@ -148,3 +148,43 @@ roughly mean-centered. and will thus produce similar values to normal Pearson correlation. However, to avoid making any assumptions about the signal for any given channel / window, PyPREP defaults to normal Pearson correlation unless strict MATLAB equivalence is requested. + + +Differences in Robust Referencing +--------------------------------- + +During the robust referencing part of the pipeline, PREP tries to estimate a +"clean" average reference signal for the dataset, excluding any channels +flagged as noisy from contaminating the reference. The robust referencing +process is performed using the following logic: + +1) First, an initial pass of noisy channel detection is performed to identify + channels bad by NaN values, flat signal, or low SNR: the data is then + average-referenced excluding these channels. These channels are subsequently + marked as "unusable" and are excluded from any future average referencing. + +2) Noisy channel detection is performed on a copy of the re-referenced signal, + and any newly detected bad channels are added to the full set of channels + to be excluded from the reference signal. + +3) After noisy channel detection, all bad channels detected so far are + interpolated, and a new estimate of the robust average reference is + calculated using the mean signal of all good channels and all interpolated + bad channels (except those flagged as "unusable" during the first step). + +4) A fresh copy of the re-referenced signal from Step 1 is re-referenced using + the new reference signal calculated in Step 3. + +5) Steps 2 through 4 are repeated until either two iterations have passed and + no new noisy channels have been detected since the previous iteration, or + the maximum number of reference iterations has been exceeded (default: 4). + + +Exclusion of dropout channels +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In MATLAB PREP, dropout channels (i.e., channels that have intermittent periods +of flat signal) are detected on each iteration of the reference loop, but are +currently not factored into the full set of "bad" channels to be interpolated. +By contrast, PyPREP will detect and interpolate any bad-by-dropout channels +detected during robust referencing. diff --git a/docs/whats_new.rst b/docs/whats_new.rst index ed0d7e2d..e1a78421 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -41,12 +41,16 @@ Changelog - Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`) - Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`) - Added a ``matlab_strict`` method for high-pass trend removal, exactly matching MATLAB PREP's values if ``matlab_strict`` is enabled, by `Austin Hurst`_ (:gh:`71`) -- Added a window-wise implementaion of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`) +- Added a window-wise implementation of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`) - Added `max_chunk_size` parameter for specifying the maximum chunk size to use for channel-wise RANSAC, allowing more control over PyPREP RAM usage, by `Austin Hurst`_ (:gh:`66`) - Changed :class:`~pyprep.Reference` to exclude "bad-by-SNR" channels from initial average referencing, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`) - Changed :class:`~pyprep.Reference` to only flag "unusable" channels (bad by flat, NaNs, or low SNR) from the first pass of noisy detection for permanent exclusion from the reference signal, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`) - Added a framework for automated testing of PyPREP's components against their MATLAB PREP counterparts (using ``.mat`` and ``.set`` files generated with the `matprep_artifacts`_ script), helping verify that the two PREP implementations are numerically equivalent when `matlab_strict` is ``True``, by `Austin Hurst`_ (:gh:`79`) - Changed :class:`~pyprep.NoisyChannels` to reuse the same random state for each run of RANSAC when ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`89`) +- Added a new argument `as_dict` for :meth:`~pyprep.NoisyChannels.get_bads`, allowing easier retrieval of flagged noisy channels by category, by `Austin Hurst`_ (:gh:`93`) +- Added a new argument `max_iterations` for :meth:`~pyprep.Reference.perform_reference` and :meth:`~pyprep.Reference.robust_reference`, allowing the maximum number of referencing iterations to be user-configurable, by `Austin Hurst`_ (:gh:`93`) +- Changed :meth:`~pyprep.Reference.robust_reference` to ignore bad-by-dropout channels during referencing if ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`) +- Changed :meth:`~pyprep.Reference.robust_reference` to allow initial bad-by-SNR channels to be used for rereferencing interpolation if no longer bad following initial average reference, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`) .. _matprep_artifacts: https://github.com/a-hurst/matprep_artifacts diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 21ee56f8..b4592a64 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -125,46 +125,61 @@ def _get_filtered_data(self): return EEG_filt - def get_bads(self, verbose=False): - """Get a list of all channels currently flagged as bad. + def get_bads(self, verbose=False, as_dict=False): + """Get the names of all channels currently flagged as bad. Note that this method does not perform any bad channel detection itself, and only reports channels already detected as bad by other methods. Parameters ---------- - verbose : bool + verbose : bool, optional If ``True``, a summary of the channels currently flagged as by bad per category is printed. Defaults to ``False``. + as_dict: bool, optional + If ``True``, this method will return a dict of the channels currently + flagged as bad by each individual bad channel type. If ``False``, this + method will return a list of all unique bad channels detected so far. + Defaults to ``False``. Returns ------- - bads : list - THe names of all bad channels detected by any method so far. + bads : list or dict + The names of all bad channels detected so far, either as a combined + list or a dict indicating the channels flagged bad by each type. """ bads = { - "n/a": self.bad_by_nan, - "flat": self.bad_by_flat, - "deviation": self.bad_by_deviation, - "hf noise": self.bad_by_hf_noise, - "correl": self.bad_by_correlation, - "SNR": self.bad_by_SNR, - "dropout": self.bad_by_dropout, - "RANSAC": self.bad_by_ransac, + "bad_by_nan": self.bad_by_nan, + "bad_by_flat": self.bad_by_flat, + "bad_by_deviation": self.bad_by_deviation, + "bad_by_hf_noise": self.bad_by_hf_noise, + "bad_by_correlation": self.bad_by_correlation, + "bad_by_SNR": self.bad_by_SNR, + "bad_by_dropout": self.bad_by_dropout, + "bad_by_ransac": self.bad_by_ransac, } all_bads = set() for bad_chs in bads.values(): all_bads.update(bad_chs) + name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"} if verbose: out = f"Found {len(all_bads)} uniquely bad channels:\n" for bad_type, bad_chs in bads.items(): + bad_type = bad_type.replace("bad_by_", "") + if bad_type in name_map.keys(): + bad_type = name_map[bad_type] out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n" print(out) - return list(all_bads) + if as_dict: + bads["bad_all"] = list(all_bads) + else: + bads = list(all_bads) + + return bads def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None): """Call all the functions to detect bad channels. diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index d3468ce7..7146c402 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -33,6 +33,9 @@ class PrepPipeline: For example, for 60Hz you may specify ``np.arange(60, sfreq / 2, 60)``. Specify an empty list to skip the line noise removal step. + - max_iterations : int, optional + - The maximum number of iterations of noisy channel removal to + perform during robust referencing. Defaults to ``4``. montage : mne.channels.DigMontage Digital montage of EEG data. ransac : bool, optional @@ -150,6 +153,8 @@ def __init__( self.prep_params["ref_chs"] = self.ch_names_eeg if self.prep_params["reref_chs"] == "eeg": self.prep_params["reref_chs"] = self.ch_names_eeg + if "max_iterations" not in prep_params.keys(): + self.prep_params["max_iterations"] = 4 self.sfreq = self.raw_eeg.info["sfreq"] self.ransac_settings = { "ransac": ransac, @@ -215,7 +220,7 @@ def fit(self): matlab_strict=self.matlab_strict, **self.ransac_settings, ) - reference.perform_reference() + reference.perform_reference(self.prep_params["max_iterations"]) self.raw_eeg = reference.raw self.noisy_channels_original = reference.noisy_channels_original self.noisy_channels_before_interpolation = ( diff --git a/pyprep/reference.py b/pyprep/reference.py index 8aa4b04d..646dba0e 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -93,9 +93,15 @@ def __init__( self._extra_info = {} self.matlab_strict = matlab_strict - def perform_reference(self): + def perform_reference(self, max_iterations=4): """Estimate the true signal mean and interpolate bad channels. + Parameters + ---------- + max_iterations : int, optional + The maximum number of iterations of noisy channel removal to perform + during robust referencing. Defaults to ``4``. + This function implements the functionality of the `performReference` function as part of the PREP pipeline on mne raw object. @@ -107,7 +113,7 @@ def perform_reference(self): """ # Phase 1: Estimate the true signal mean with robust referencing - self.robust_reference() + self.robust_reference(max_iterations) # If we interpolate the raw here we would be interpolating # more than what we later actually account for (in interpolated channels). dummy = self.raw.copy() @@ -134,17 +140,7 @@ def perform_reference(self): # Record Noisy channels and EEG before interpolation self.bad_before_interpolation = noisy_detector.get_bads(verbose=True) self.EEG_before_interpolation = self.EEG.copy() - self.noisy_channels_before_interpolation = { - "bad_by_nan": noisy_detector.bad_by_nan, - "bad_by_flat": noisy_detector.bad_by_flat, - "bad_by_deviation": noisy_detector.bad_by_deviation, - "bad_by_hf_noise": noisy_detector.bad_by_hf_noise, - "bad_by_correlation": noisy_detector.bad_by_correlation, - "bad_by_SNR": noisy_detector.bad_by_SNR, - "bad_by_dropout": noisy_detector.bad_by_dropout, - "bad_by_ransac": noisy_detector.bad_by_ransac, - "bad_all": noisy_detector.get_bads(), - } + self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True) self._extra_info["interpolated"] = noisy_detector._extra_info bad_channels = _union(self.bad_before_interpolation, self.unusable_channels) @@ -170,27 +166,23 @@ def perform_reference(self): noisy_detector.find_all_bads(**self.ransac_settings) self.still_noisy_channels = noisy_detector.get_bads() self.raw.info["bads"] = self.still_noisy_channels - self.noisy_channels_after_interpolation = { - "bad_by_nan": noisy_detector.bad_by_nan, - "bad_by_flat": noisy_detector.bad_by_flat, - "bad_by_deviation": noisy_detector.bad_by_deviation, - "bad_by_hf_noise": noisy_detector.bad_by_hf_noise, - "bad_by_correlation": noisy_detector.bad_by_correlation, - "bad_by_SNR": noisy_detector.bad_by_SNR, - "bad_by_dropout": noisy_detector.bad_by_dropout, - "bad_by_ransac": noisy_detector.bad_by_ransac, - "bad_all": noisy_detector.get_bads(), - } + self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True) self._extra_info["remaining_bad"] = noisy_detector._extra_info return self - def robust_reference(self): + def robust_reference(self, max_iterations=4): """Detect bad channels and estimate the robust reference signal. This function implements the functionality of the `robustReference` function as part of the PREP pipeline on mne raw object. + Parameters + ---------- + max_iterations : int, optional + The maximum number of iterations of noisy channel removal to perform + during robust referencing. Defaults to ``4``. + Returns ------- noisy_channels: dict @@ -213,17 +205,7 @@ def robust_reference(self): matlab_strict=self.matlab_strict, ) noisy_detector.find_all_bads(**self.ransac_settings) - self.noisy_channels_original = { - "bad_by_nan": noisy_detector.bad_by_nan, - "bad_by_flat": noisy_detector.bad_by_flat, - "bad_by_deviation": noisy_detector.bad_by_deviation, - "bad_by_hf_noise": noisy_detector.bad_by_hf_noise, - "bad_by_correlation": noisy_detector.bad_by_correlation, - "bad_by_SNR": noisy_detector.bad_by_SNR, - "bad_by_dropout": noisy_detector.bad_by_dropout, - "bad_by_ransac": noisy_detector.bad_by_ransac, - "bad_all": noisy_detector.get_bads(), - } + self.noisy_channels_original = noisy_detector.get_bads(as_dict=True) self._extra_info["initial_bad"] = noisy_detector._extra_info logger.info("Bad channels: {}".format(self.noisy_channels_original)) @@ -235,16 +217,16 @@ def robust_reference(self): reference_channels = _set_diff(self.reference_channels, self.unusable_channels) # Initialize channels to permanently flag as bad during referencing - self.noisy_channels = { + noisy = { "bad_by_nan": noisy_detector.bad_by_nan, "bad_by_flat": noisy_detector.bad_by_flat, "bad_by_deviation": [], "bad_by_hf_noise": [], "bad_by_correlation": [], - "bad_by_SNR": noisy_detector.bad_by_SNR, + "bad_by_SNR": [], "bad_by_dropout": [], "bad_by_ransac": [], - "bad_all": self.unusable_channels, + "bad_all": [], } # Get initial estimate of the reference by the specified method @@ -260,8 +242,7 @@ def robust_reference(self): # Remove reference from signal, iteratively interpolating bad channels raw_tmp = raw.copy() iterations = 0 - noisy_channels_old = [] - max_iteration_num = 4 + previous_bads = set() while True: raw_tmp._data = signal_tmp * 1e-6 @@ -272,51 +253,46 @@ def robust_reference(self): matlab_strict=self.matlab_strict, ) # Detrend applied at the beginning of the function. + + # Detect all currently bad channels noisy_detector.find_all_bads(**self.ransac_settings) - self.noisy_channels["bad_by_nan"] = _union( - self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan - ) - self.noisy_channels["bad_by_flat"] = _union( - self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat - ) - self.noisy_channels["bad_by_deviation"] = _union( - self.noisy_channels["bad_by_deviation"], noisy_detector.bad_by_deviation - ) - self.noisy_channels["bad_by_hf_noise"] = _union( - self.noisy_channels["bad_by_hf_noise"], noisy_detector.bad_by_hf_noise - ) - self.noisy_channels["bad_by_correlation"] = _union( - self.noisy_channels["bad_by_correlation"], - noisy_detector.bad_by_correlation, - ) - self.noisy_channels["bad_by_ransac"] = _union( - self.noisy_channels["bad_by_ransac"], noisy_detector.bad_by_ransac - ) - self.noisy_channels["bad_all"] = _union( - self.noisy_channels["bad_all"], noisy_detector.get_bads() - ) - logger.info("Bad channels: {}".format(self.noisy_channels)) + noisy_new = noisy_detector.get_bads(as_dict=True) + + # Specify bad channel types to ignore when updating noisy channels + # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake? + # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28 + ignore = ["bad_by_SNR", "bad_all"] + if self.matlab_strict: + ignore += ["bad_by_dropout"] + + # Update set of all noisy channels detected so far with any new ones + bad_chans = set() + for bad_type in noisy_new.keys(): + noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type]) + if bad_type not in ignore: + bad_chans.update(noisy[bad_type]) + noisy["bad_all"] = list(bad_chans) + logger.info("Bad channels: {}".format(noisy)) if ( iterations > 1 - and ( - not self.noisy_channels["bad_all"] - or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old) - ) - or iterations > max_iteration_num + and (len(bad_chans) == 0 or bad_chans == previous_bads) + or iterations > max_iterations ): + logger.info("Robust reference done") + self.noisy_channels = noisy break - noisy_channels_old = self.noisy_channels["bad_all"].copy() + previous_bads = bad_chans.copy() - if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2: + if raw_tmp.info["nchan"] - len(bad_chans) < 2: raise ValueError( "RobustReference:TooManyBad " "Could not perform a robust reference -- not enough good channels" ) - if self.noisy_channels["bad_all"]: + if len(bad_chans) > 0: raw_tmp._data = signal * 1e-6 - raw_tmp.info["bads"] = self.noisy_channels["bad_all"] + raw_tmp.info["bads"] = list(bad_chans) raw_tmp.interpolate_bads() signal_tmp = raw_tmp.get_data() * 1e6 else: @@ -331,7 +307,6 @@ def robust_reference(self): iterations = iterations + 1 logger.info("Iterations: {}".format(iterations)) - logger.info("Robust reference done") return self.noisy_channels, self.reference_signal @staticmethod diff --git a/tests/test_reference.py b/tests/test_reference.py index 18f0c713..4cb4ed11 100644 --- a/tests/test_reference.py +++ b/tests/test_reference.py @@ -1,5 +1,5 @@ """Test Robust Reference.""" -import random +from unittest import mock import mne import numpy as np @@ -30,31 +30,39 @@ def test_basic_input(raw, montage): assert params["ref_chs"] == reference.reference_channels -@pytest.mark.usefixtures("raw", "montage") -def test_all_bad_input(raw, montage): - """Test robust reference when all reference channels are bad.""" - ch_names = raw.info["ch_names"] +@pytest.mark.usefixtures("raw_clean") +def test_clean_input(raw_clean): + """Test robust referencing with a clean input signal.""" + ch_names = raw_clean.info["ch_names"] + params = {"ref_chs": ch_names, "reref_chs": ch_names} - raw_tmp = raw.copy() - raw_tmp.set_montage(montage) - m, n = raw_tmp.get_data().shape + # Here we monkey-patch Reference to skip bad channel detection, ensuring + # a run with all clean channels is tested + with mock.patch("pyprep.NoisyChannels.find_all_bads", return_value=True): + reference = Reference(raw_clean, params, ransac=False) + reference.robust_reference() - # Randomly set some channels as bad - [nan_chn_idx, flat_chn_idx] = random.sample(set(np.arange(0, m)), 2) + assert len(reference.unusable_channels) == 0 + assert len(reference.noisy_channels_original["bad_all"]) == 0 + assert len(reference.noisy_channels["bad_all"]) == 0 - # Insert a nan value for a random channel - # nan_chn_lab = raw_tmp.ch_names[nan_chn_idx] - raw_tmp._data[nan_chn_idx, n - 1] = np.nan - # Insert one random flat channel - # flat_chn_lab = raw_tmp.ch_names[flat_chn_idx] - raw_tmp._data[flat_chn_idx, :] = np.ones_like(raw_tmp._data[1, :]) * 1e-6 +@pytest.mark.usefixtures("raw_clean") +def test_all_bad_input(raw_clean): + """Test robust reference when all reference channels are bad.""" + ch_names = raw_clean.info["ch_names"] + params = {"ref_chs": ch_names, "reref_chs": ch_names} - reference_channels = [ch_names[nan_chn_idx], ch_names[flat_chn_idx]] - params = {"ref_chs": reference_channels, "reref_chs": reference_channels} - reference = Reference(raw_tmp, params, ransac=False) - with pytest.raises(ValueError): - reference.robust_reference() + # Define a mock function to make all channels bad by deviation + def _bad_by_dev(self): + self.bad_by_deviation = self.ch_names_original.tolist() + + # Here we monkey-patch Reference to make all channels bad by deviation, allowing + # us to test the 'too-few-good-channels' exception + with mock.patch("pyprep.NoisyChannels.find_bad_by_deviation", new=_bad_by_dev): + reference = Reference(raw_clean, params, ransac=False) + with pytest.raises(ValueError): + reference.robust_reference() def test_remove_reference():