Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make noisy channel exclusion during Reference compatible with MATLAB PREP #93

Merged
merged 16 commits into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/matlab_differences.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,43 @@ roughly mean-centered. and will thus produce similar values to normal Pearson
correlation. However, to avoid making any assumptions about the signal for any
given channel / window, PyPREP defaults to normal Pearson correlation unless
strict MATLAB equivalence is requested.


Differences in Robust Referencing
---------------------------------

During the robust referencing part of the pipeline, PREP tries to estimate a
"clean" average reference signal for the dataset, excluding any channels
flagged as noisy from contaminating the reference. The robust referencing
process is performed using the following logic:

1) First, an initial pass of noisy channel detection is performed to identify
channels bad by NaN values, flat signal, or low SNR: the data is then
average-referenced excluding these channels. These channels are subsequently
marked as "unusable" and are excluded from any future average referencing.

2) Noisy channel detection is performed on a copy of the re-referenced signal,
and any newly detected bad channels are added to the full set of channels
to be excluded from the reference signal.

3) After noisy channel detection, all bad channels detected so far are
interpolated, and a new estimate of the robust average reference is
calculated using the mean signal of all good channels and all interpolated
bad channels (except those flagged as "unusable" during the first step).

4) A fresh copy of the re-referenced signal from Step 1 is re-referenced using
the new reference signal calculated in Step 3.

5) Steps 2 through 4 are repeated until either two iterations have passed and
no new noisy channels have been detected since the previous iteration, or
the maximum number of reference iterations has been exceeded (default: 4).


Exclusion of dropout channels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In MATLAB PREP, dropout channels (i.e., channels that have intermittent periods
of flat signal) are detected on each iteration of the reference loop, but are
currently not factored into the full set of "bad" channels to be interpolated.
By contrast, PyPREP will detect and interpolate any bad-by-dropout channels
detected during robust referencing.
6 changes: 5 additions & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@ Changelog
- Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`)
- Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`)
- Added a ``matlab_strict`` method for high-pass trend removal, exactly matching MATLAB PREP's values if ``matlab_strict`` is enabled, by `Austin Hurst`_ (:gh:`71`)
- Added a window-wise implementaion of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`)
- Added a window-wise implementation of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`)
- Added `max_chunk_size` parameter for specifying the maximum chunk size to use for channel-wise RANSAC, allowing more control over PyPREP RAM usage, by `Austin Hurst`_ (:gh:`66`)
- Changed :class:`~pyprep.Reference` to exclude "bad-by-SNR" channels from initial average referencing, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`)
- Changed :class:`~pyprep.Reference` to only flag "unusable" channels (bad by flat, NaNs, or low SNR) from the first pass of noisy detection for permanent exclusion from the reference signal, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`)
- Added a framework for automated testing of PyPREP's components against their MATLAB PREP counterparts (using ``.mat`` and ``.set`` files generated with the `matprep_artifacts`_ script), helping verify that the two PREP implementations are numerically equivalent when `matlab_strict` is ``True``, by `Austin Hurst`_ (:gh:`79`)
- Changed :class:`~pyprep.NoisyChannels` to reuse the same random state for each run of RANSAC when ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`89`)
- Added a new argument `as_dict` for :meth:`~pyprep.NoisyChannels.get_bads`, allowing easier retrieval of flagged noisy channels by category, by `Austin Hurst`_ (:gh:`93`)
- Added a new argument `max_iterations` for :meth:`~pyprep.Reference.perform_reference` and :meth:`~pyprep.Reference.robust_reference`, allowing the maximum number of referencing iterations to be user-configurable, by `Austin Hurst`_ (:gh:`93`)
- Changed :meth:`~pyprep.Reference.robust_reference` to ignore bad-by-dropout channels during referencing if ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)
- Changed :meth:`~pyprep.Reference.robust_reference` to allow initial bad-by-SNR channels to be used for rereferencing interpolation if no longer bad following initial average reference, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)

.. _matprep_artifacts: https://github.com/a-hurst/matprep_artifacts

Expand Down
43 changes: 29 additions & 14 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,46 +125,61 @@ def _get_filtered_data(self):

return EEG_filt

def get_bads(self, verbose=False):
"""Get a list of all channels currently flagged as bad.
def get_bads(self, verbose=False, as_dict=False):
"""Get the names of all channels currently flagged as bad.

Note that this method does not perform any bad channel detection itself,
and only reports channels already detected as bad by other methods.

Parameters
----------
verbose : bool
verbose : bool, optional
If ``True``, a summary of the channels currently flagged as by bad per
category is printed. Defaults to ``False``.
as_dict: bool, optional
If ``True``, this method will return a dict of the channels currently
flagged as bad by each individual bad channel type. If ``False``, this
method will return a list of all unique bad channels detected so far.
Defaults to ``False``.

Returns
-------
bads : list
THe names of all bad channels detected by any method so far.
bads : list or dict
The names of all bad channels detected so far, either as a combined
list or a dict indicating the channels flagged bad by each type.

"""
bads = {
"n/a": self.bad_by_nan,
"flat": self.bad_by_flat,
"deviation": self.bad_by_deviation,
"hf noise": self.bad_by_hf_noise,
"correl": self.bad_by_correlation,
"SNR": self.bad_by_SNR,
"dropout": self.bad_by_dropout,
"RANSAC": self.bad_by_ransac,
"bad_by_nan": self.bad_by_nan,
"bad_by_flat": self.bad_by_flat,
"bad_by_deviation": self.bad_by_deviation,
"bad_by_hf_noise": self.bad_by_hf_noise,
"bad_by_correlation": self.bad_by_correlation,
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
}

all_bads = set()
for bad_chs in bads.values():
all_bads.update(bad_chs)

name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
if verbose:
out = f"Found {len(all_bads)} uniquely bad channels:\n"
for bad_type, bad_chs in bads.items():
bad_type = bad_type.replace("bad_by_", "")
if bad_type in name_map.keys():
bad_type = name_map[bad_type]
out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n"
print(out)

return list(all_bads)
if as_dict:
bads["bad_all"] = list(all_bads)
else:
bads = list(all_bads)

return bads

def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
"""Call all the functions to detect bad channels.
Expand Down
7 changes: 6 additions & 1 deletion pyprep/prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class PrepPipeline:
For example, for 60Hz you may specify
``np.arange(60, sfreq / 2, 60)``. Specify an empty list to
skip the line noise removal step.
- max_iterations : int, optional
- The maximum number of iterations of noisy channel removal to
perform during robust referencing. Defaults to ``4``.
montage : mne.channels.DigMontage
Digital montage of EEG data.
ransac : bool, optional
Expand Down Expand Up @@ -150,6 +153,8 @@ def __init__(
self.prep_params["ref_chs"] = self.ch_names_eeg
if self.prep_params["reref_chs"] == "eeg":
self.prep_params["reref_chs"] = self.ch_names_eeg
if "max_iterations" not in prep_params.keys():
self.prep_params["max_iterations"] = 4
self.sfreq = self.raw_eeg.info["sfreq"]
self.ransac_settings = {
"ransac": ransac,
Expand Down Expand Up @@ -215,7 +220,7 @@ def fit(self):
matlab_strict=self.matlab_strict,
**self.ransac_settings,
)
reference.perform_reference()
reference.perform_reference(self.prep_params["max_iterations"])
self.raw_eeg = reference.raw
self.noisy_channels_original = reference.noisy_channels_original
self.noisy_channels_before_interpolation = (
Expand Down
123 changes: 49 additions & 74 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,15 @@ def __init__(
self._extra_info = {}
self.matlab_strict = matlab_strict

def perform_reference(self):
def perform_reference(self, max_iterations=4):
"""Estimate the true signal mean and interpolate bad channels.

Parameters
----------
max_iterations : int, optional
The maximum number of iterations of noisy channel removal to perform
during robust referencing. Defaults to ``4``.

This function implements the functionality of the `performReference` function
as part of the PREP pipeline on mne raw object.

Expand All @@ -107,7 +113,7 @@ def perform_reference(self):

"""
# Phase 1: Estimate the true signal mean with robust referencing
self.robust_reference()
self.robust_reference(max_iterations)
# If we interpolate the raw here we would be interpolating
# more than what we later actually account for (in interpolated channels).
dummy = self.raw.copy()
Expand All @@ -134,17 +140,7 @@ def perform_reference(self):
# Record Noisy channels and EEG before interpolation
self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
self.EEG_before_interpolation = self.EEG.copy()
self.noisy_channels_before_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
self._extra_info["interpolated"] = noisy_detector._extra_info

bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
Expand All @@ -170,27 +166,23 @@ def perform_reference(self):
noisy_detector.find_all_bads(**self.ransac_settings)
self.still_noisy_channels = noisy_detector.get_bads()
self.raw.info["bads"] = self.still_noisy_channels
self.noisy_channels_after_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True)
self._extra_info["remaining_bad"] = noisy_detector._extra_info

return self

def robust_reference(self):
def robust_reference(self, max_iterations=4):
"""Detect bad channels and estimate the robust reference signal.

This function implements the functionality of the `robustReference` function
as part of the PREP pipeline on mne raw object.

Parameters
----------
max_iterations : int, optional
The maximum number of iterations of noisy channel removal to perform
during robust referencing. Defaults to ``4``.

Returns
-------
noisy_channels: dict
Expand All @@ -213,17 +205,7 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels_original = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
self._extra_info["initial_bad"] = noisy_detector._extra_info
logger.info("Bad channels: {}".format(self.noisy_channels_original))

Expand All @@ -235,16 +217,16 @@ def robust_reference(self):
reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

# Initialize channels to permanently flag as bad during referencing
self.noisy_channels = {
noisy = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": [],
"bad_by_hf_noise": [],
"bad_by_correlation": [],
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_SNR": [],
"bad_by_dropout": [],
"bad_by_ransac": [],
"bad_all": self.unusable_channels,
"bad_all": [],
}

# Get initial estimate of the reference by the specified method
Expand All @@ -260,8 +242,7 @@ def robust_reference(self):
# Remove reference from signal, iteratively interpolating bad channels
raw_tmp = raw.copy()
iterations = 0
noisy_channels_old = []
max_iteration_num = 4
previous_bads = set()

while True:
raw_tmp._data = signal_tmp * 1e-6
Expand All @@ -272,51 +253,46 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
# Detrend applied at the beginning of the function.

# Detect all currently bad channels
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels["bad_by_nan"] = _union(
self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan
)
self.noisy_channels["bad_by_flat"] = _union(
self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat
)
self.noisy_channels["bad_by_deviation"] = _union(
self.noisy_channels["bad_by_deviation"], noisy_detector.bad_by_deviation
)
self.noisy_channels["bad_by_hf_noise"] = _union(
self.noisy_channels["bad_by_hf_noise"], noisy_detector.bad_by_hf_noise
)
self.noisy_channels["bad_by_correlation"] = _union(
self.noisy_channels["bad_by_correlation"],
noisy_detector.bad_by_correlation,
)
self.noisy_channels["bad_by_ransac"] = _union(
self.noisy_channels["bad_by_ransac"], noisy_detector.bad_by_ransac
)
self.noisy_channels["bad_all"] = _union(
self.noisy_channels["bad_all"], noisy_detector.get_bads()
)
logger.info("Bad channels: {}".format(self.noisy_channels))
noisy_new = noisy_detector.get_bads(as_dict=True)

# Specify bad channel types to ignore when updating noisy channels
# NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
# see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
ignore = ["bad_by_SNR", "bad_all"]
if self.matlab_strict:
ignore += ["bad_by_dropout"]

# Update set of all noisy channels detected so far with any new ones
bad_chans = set()
for bad_type in noisy_new.keys():
noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
if bad_type not in ignore:
bad_chans.update(noisy[bad_type])
noisy["bad_all"] = list(bad_chans)
logger.info("Bad channels: {}".format(noisy))

if (
iterations > 1
and (
not self.noisy_channels["bad_all"]
or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old)
)
or iterations > max_iteration_num
and (len(bad_chans) == 0 or bad_chans == previous_bads)
or iterations > max_iterations
):
logger.info("Robust reference done")
self.noisy_channels = noisy
break
noisy_channels_old = self.noisy_channels["bad_all"].copy()
previous_bads = bad_chans.copy()

if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
if raw_tmp.info["nchan"] - len(bad_chans) < 2:
raise ValueError(
"RobustReference:TooManyBad "
"Could not perform a robust reference -- not enough good channels"
)

if self.noisy_channels["bad_all"]:
if len(bad_chans) > 0:
raw_tmp._data = signal * 1e-6
raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
raw_tmp.info["bads"] = list(bad_chans)
raw_tmp.interpolate_bads()
signal_tmp = raw_tmp.get_data() * 1e6
else:
Expand All @@ -331,7 +307,6 @@ def robust_reference(self):
iterations = iterations + 1
logger.info("Iterations: {}".format(iterations))

logger.info("Robust reference done")
return self.noisy_channels, self.reference_signal

@staticmethod
Expand Down
Loading