Skip to content

Commit

Permalink
Make noisy channel exclusion during Reference compatible with MATLAB …
Browse files Browse the repository at this point in the history
…PREP (#93)

* Add "as_dict" option for get_bads

* Don't permanently exclude initial bad-by-SNRs

* Match PREP's noisy channel updating logic

* Minor variable name cleanup

* use as_dict throughout reference.py

* Fix 'ignore' logic

* Add back "bad_all" (whoops)

* Update matlab_differences.rst

* Add max_iterations args, link to dropout issue

* Fix diffs mistake, update whats_new.rst

* Add PrepPipeline API for max_iterations

* Improve test coverage for Reference

* Add whats_new entry for SNR changes

* Fix quotes to make black happy

* Improve Reference test coverage some more

* remove unused import
  • Loading branch information
a-hurst authored and sappelhoff committed Jul 7, 2021
1 parent dbe2062 commit 1e34468
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 110 deletions.
40 changes: 40 additions & 0 deletions docs/matlab_differences.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,43 @@ roughly mean-centered. and will thus produce similar values to normal Pearson
correlation. However, to avoid making any assumptions about the signal for any
given channel / window, PyPREP defaults to normal Pearson correlation unless
strict MATLAB equivalence is requested.


Differences in Robust Referencing
---------------------------------

During the robust referencing part of the pipeline, PREP tries to estimate a
"clean" average reference signal for the dataset, excluding any channels
flagged as noisy from contaminating the reference. The robust referencing
process is performed using the following logic:

1) First, an initial pass of noisy channel detection is performed to identify
channels bad by NaN values, flat signal, or low SNR: the data is then
average-referenced excluding these channels. These channels are subsequently
marked as "unusable" and are excluded from any future average referencing.

2) Noisy channel detection is performed on a copy of the re-referenced signal,
and any newly detected bad channels are added to the full set of channels
to be excluded from the reference signal.

3) After noisy channel detection, all bad channels detected so far are
interpolated, and a new estimate of the robust average reference is
calculated using the mean signal of all good channels and all interpolated
bad channels (except those flagged as "unusable" during the first step).

4) A fresh copy of the re-referenced signal from Step 1 is re-referenced using
the new reference signal calculated in Step 3.

5) Steps 2 through 4 are repeated until either two iterations have passed and
no new noisy channels have been detected since the previous iteration, or
the maximum number of reference iterations has been exceeded (default: 4).


Exclusion of dropout channels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In MATLAB PREP, dropout channels (i.e., channels that have intermittent periods
of flat signal) are detected on each iteration of the reference loop, but are
currently not factored into the full set of "bad" channels to be interpolated.
By contrast, PyPREP will detect and interpolate any bad-by-dropout channels
detected during robust referencing.
4 changes: 4 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ Changelog
- Changed :class:`~pyprep.Reference` to only flag "unusable" channels (bad by flat, NaNs, or low SNR) from the first pass of noisy detection for permanent exclusion from the reference signal, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`)
- Added a framework for automated testing of PyPREP's components against their MATLAB PREP counterparts (using ``.mat`` and ``.set`` files generated with the `matprep_artifacts`_ script), helping verify that the two PREP implementations are numerically equivalent when `matlab_strict` is ``True``, by `Austin Hurst`_ (:gh:`79`)
- Changed :class:`~pyprep.NoisyChannels` to reuse the same random state for each run of RANSAC when ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`89`)
- Added a new argument `as_dict` for :meth:`~pyprep.NoisyChannels.get_bads`, allowing easier retrieval of flagged noisy channels by category, by `Austin Hurst`_ (:gh:`93`)
- Added a new argument `max_iterations` for :meth:`~pyprep.Reference.perform_reference` and :meth:`~pyprep.Reference.robust_reference`, allowing the maximum number of referencing iterations to be user-configurable, by `Austin Hurst`_ (:gh:`93`)
- Changed :meth:`~pyprep.Reference.robust_reference` to ignore bad-by-dropout channels during referencing if ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)
- Changed :meth:`~pyprep.Reference.robust_reference` to allow initial bad-by-SNR channels to be used for rereferencing interpolation if no longer bad following initial average reference, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)

.. _matprep_artifacts: https://github.com/a-hurst/matprep_artifacts

Expand Down
43 changes: 29 additions & 14 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,46 +125,61 @@ def _get_filtered_data(self):

return EEG_filt

def get_bads(self, verbose=False):
"""Get a list of all channels currently flagged as bad.
def get_bads(self, verbose=False, as_dict=False):
"""Get the names of all channels currently flagged as bad.
Note that this method does not perform any bad channel detection itself,
and only reports channels already detected as bad by other methods.
Parameters
----------
verbose : bool
verbose : bool, optional
If ``True``, a summary of the channels currently flagged as by bad per
category is printed. Defaults to ``False``.
as_dict: bool, optional
If ``True``, this method will return a dict of the channels currently
flagged as bad by each individual bad channel type. If ``False``, this
method will return a list of all unique bad channels detected so far.
Defaults to ``False``.
Returns
-------
bads : list
THe names of all bad channels detected by any method so far.
bads : list or dict
The names of all bad channels detected so far, either as a combined
list or a dict indicating the channels flagged bad by each type.
"""
bads = {
"n/a": self.bad_by_nan,
"flat": self.bad_by_flat,
"deviation": self.bad_by_deviation,
"hf noise": self.bad_by_hf_noise,
"correl": self.bad_by_correlation,
"SNR": self.bad_by_SNR,
"dropout": self.bad_by_dropout,
"RANSAC": self.bad_by_ransac,
"bad_by_nan": self.bad_by_nan,
"bad_by_flat": self.bad_by_flat,
"bad_by_deviation": self.bad_by_deviation,
"bad_by_hf_noise": self.bad_by_hf_noise,
"bad_by_correlation": self.bad_by_correlation,
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
}

all_bads = set()
for bad_chs in bads.values():
all_bads.update(bad_chs)

name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
if verbose:
out = f"Found {len(all_bads)} uniquely bad channels:\n"
for bad_type, bad_chs in bads.items():
bad_type = bad_type.replace("bad_by_", "")
if bad_type in name_map.keys():
bad_type = name_map[bad_type]
out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n"
print(out)

return list(all_bads)
if as_dict:
bads["bad_all"] = list(all_bads)
else:
bads = list(all_bads)

return bads

def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
"""Call all the functions to detect bad channels.
Expand Down
7 changes: 6 additions & 1 deletion pyprep/prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class PrepPipeline:
For example, for 60Hz you may specify
``np.arange(60, sfreq / 2, 60)``. Specify an empty list to
skip the line noise removal step.
- max_iterations : int, optional
- The maximum number of iterations of noisy channel removal to
perform during robust referencing. Defaults to ``4``.
montage : mne.channels.DigMontage
Digital montage of EEG data.
ransac : bool, optional
Expand Down Expand Up @@ -150,6 +153,8 @@ def __init__(
self.prep_params["ref_chs"] = self.ch_names_eeg
if self.prep_params["reref_chs"] == "eeg":
self.prep_params["reref_chs"] = self.ch_names_eeg
if "max_iterations" not in prep_params.keys():
self.prep_params["max_iterations"] = 4
self.sfreq = self.raw_eeg.info["sfreq"]
self.ransac_settings = {
"ransac": ransac,
Expand Down Expand Up @@ -215,7 +220,7 @@ def fit(self):
matlab_strict=self.matlab_strict,
**self.ransac_settings,
)
reference.perform_reference()
reference.perform_reference(self.prep_params["max_iterations"])
self.raw_eeg = reference.raw
self.noisy_channels_original = reference.noisy_channels_original
self.noisy_channels_before_interpolation = (
Expand Down
123 changes: 49 additions & 74 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,15 @@ def __init__(
self._extra_info = {}
self.matlab_strict = matlab_strict

def perform_reference(self):
def perform_reference(self, max_iterations=4):
"""Estimate the true signal mean and interpolate bad channels.
Parameters
----------
max_iterations : int, optional
The maximum number of iterations of noisy channel removal to perform
during robust referencing. Defaults to ``4``.
This function implements the functionality of the `performReference` function
as part of the PREP pipeline on mne raw object.
Expand All @@ -107,7 +113,7 @@ def perform_reference(self):
"""
# Phase 1: Estimate the true signal mean with robust referencing
self.robust_reference()
self.robust_reference(max_iterations)
# If we interpolate the raw here we would be interpolating
# more than what we later actually account for (in interpolated channels).
dummy = self.raw.copy()
Expand All @@ -134,17 +140,7 @@ def perform_reference(self):
# Record Noisy channels and EEG before interpolation
self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
self.EEG_before_interpolation = self.EEG.copy()
self.noisy_channels_before_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
self._extra_info["interpolated"] = noisy_detector._extra_info

bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
Expand All @@ -170,27 +166,23 @@ def perform_reference(self):
noisy_detector.find_all_bads(**self.ransac_settings)
self.still_noisy_channels = noisy_detector.get_bads()
self.raw.info["bads"] = self.still_noisy_channels
self.noisy_channels_after_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True)
self._extra_info["remaining_bad"] = noisy_detector._extra_info

return self

def robust_reference(self):
def robust_reference(self, max_iterations=4):
"""Detect bad channels and estimate the robust reference signal.
This function implements the functionality of the `robustReference` function
as part of the PREP pipeline on mne raw object.
Parameters
----------
max_iterations : int, optional
The maximum number of iterations of noisy channel removal to perform
during robust referencing. Defaults to ``4``.
Returns
-------
noisy_channels: dict
Expand All @@ -213,17 +205,7 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels_original = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
self._extra_info["initial_bad"] = noisy_detector._extra_info
logger.info("Bad channels: {}".format(self.noisy_channels_original))

Expand All @@ -235,16 +217,16 @@ def robust_reference(self):
reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

# Initialize channels to permanently flag as bad during referencing
self.noisy_channels = {
noisy = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": [],
"bad_by_hf_noise": [],
"bad_by_correlation": [],
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_SNR": [],
"bad_by_dropout": [],
"bad_by_ransac": [],
"bad_all": self.unusable_channels,
"bad_all": [],
}

# Get initial estimate of the reference by the specified method
Expand All @@ -260,8 +242,7 @@ def robust_reference(self):
# Remove reference from signal, iteratively interpolating bad channels
raw_tmp = raw.copy()
iterations = 0
noisy_channels_old = []
max_iteration_num = 4
previous_bads = set()

while True:
raw_tmp._data = signal_tmp * 1e-6
Expand All @@ -272,51 +253,46 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
# Detrend applied at the beginning of the function.

# Detect all currently bad channels
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels["bad_by_nan"] = _union(
self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan
)
self.noisy_channels["bad_by_flat"] = _union(
self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat
)
self.noisy_channels["bad_by_deviation"] = _union(
self.noisy_channels["bad_by_deviation"], noisy_detector.bad_by_deviation
)
self.noisy_channels["bad_by_hf_noise"] = _union(
self.noisy_channels["bad_by_hf_noise"], noisy_detector.bad_by_hf_noise
)
self.noisy_channels["bad_by_correlation"] = _union(
self.noisy_channels["bad_by_correlation"],
noisy_detector.bad_by_correlation,
)
self.noisy_channels["bad_by_ransac"] = _union(
self.noisy_channels["bad_by_ransac"], noisy_detector.bad_by_ransac
)
self.noisy_channels["bad_all"] = _union(
self.noisy_channels["bad_all"], noisy_detector.get_bads()
)
logger.info("Bad channels: {}".format(self.noisy_channels))
noisy_new = noisy_detector.get_bads(as_dict=True)

# Specify bad channel types to ignore when updating noisy channels
# NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
# see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
ignore = ["bad_by_SNR", "bad_all"]
if self.matlab_strict:
ignore += ["bad_by_dropout"]

# Update set of all noisy channels detected so far with any new ones
bad_chans = set()
for bad_type in noisy_new.keys():
noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
if bad_type not in ignore:
bad_chans.update(noisy[bad_type])
noisy["bad_all"] = list(bad_chans)
logger.info("Bad channels: {}".format(noisy))

if (
iterations > 1
and (
not self.noisy_channels["bad_all"]
or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old)
)
or iterations > max_iteration_num
and (len(bad_chans) == 0 or bad_chans == previous_bads)
or iterations > max_iterations
):
logger.info("Robust reference done")
self.noisy_channels = noisy
break
noisy_channels_old = self.noisy_channels["bad_all"].copy()
previous_bads = bad_chans.copy()

if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
if raw_tmp.info["nchan"] - len(bad_chans) < 2:
raise ValueError(
"RobustReference:TooManyBad "
"Could not perform a robust reference -- not enough good channels"
)

if self.noisy_channels["bad_all"]:
if len(bad_chans) > 0:
raw_tmp._data = signal * 1e-6
raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
raw_tmp.info["bads"] = list(bad_chans)
raw_tmp.interpolate_bads()
signal_tmp = raw_tmp.get_data() * 1e6
else:
Expand All @@ -331,7 +307,6 @@ def robust_reference(self):
iterations = iterations + 1
logger.info("Iterations: {}".format(iterations))

logger.info("Robust reference done")
return self.noisy_channels, self.reference_signal

@staticmethod
Expand Down
Loading

0 comments on commit 1e34468

Please sign in to comment.