Add matlab_strict API for matching MATLAB PREP's correlations and medians during RANSAC #70

Merged
merged 15 commits into from
Apr 26, 2021
Changes from 7 commits
68 changes: 68 additions & 0 deletions docs/matlab_differences.rst
@@ -0,0 +1,68 @@
Deliberate Differences from the Original PREP
=============================================

Although PyPREP aims to be a faithful reimplementation of the original MATLAB
version of PREP, there are a few places where PyPREP has deliberately chosen
to use different defaults than MATLAB PREP.

To override these differences, you can set the ``matlab_strict`` argument of
:class:`pyprep.prep_pipeline.PrepPipeline`, :class:`pyprep.reference.Reference`,
or :class:`pyprep.find_noisy_channels.NoisyChannels` to ``True`` to match the
original PREP's internal math.

Differences in RANSAC
=====================

During the "find-bad-by-RANSAC" step of noisy channel detection, PREP performs
the following steps to identify channels that aren't well-predicted by the
signals of other channels:

1) Generates a bunch of random subsets of currently-good channels from the data
(50 samples by default, each containing 25% of the total EEG channels in the
dataset).
2) Uses the signals and spatial locations of those channels to predict what the
signals will be at the spatial locations of all the other channels, with each
random subset of channels generating a different prediction for each channel
(i.e., 50 predicted signals per channel by default).
3) For each channel, calculates the median predicted signal from the full set of
predictions.
4) Splits the full data into small non-overlapping windows (5 seconds by
default) and calculates the correlation between the median predicted signal
and the actual signal for each channel within each window.
5) Compares the correlations for each channel against a threshold value (0.75
by default), flags all windows that fall below that threshold as 'bad', and
calculates the proportions of 'bad' windows for each channel.
6) Flags all channels with an excessively high proportion of 'bad' windows
(minimum 0.4 by default) as being 'bad-by-RANSAC'.
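
Steps 3 to 6 above can be sketched roughly as follows. This is only an illustration, not PyPREP's implementation: the function name ``flag_bad_by_ransac``, its argument names, the ``(n_channels, n_times, n_ransac_samples)`` layout of ``predictions``, and the use of a strict ``>`` comparison in step 6 are all assumptions made for the example, and the predictions from steps 1 and 2 are simply simulated.

```python
import numpy as np


def flag_bad_by_ransac(data, predictions, sfreq, window_secs=5.0,
                       corr_thresh=0.75, frac_bad=0.4):
    """Illustrative sketch of steps 3-6 (hypothetical helper, not PyPREP API)."""
    # Step 3: median predicted signal per channel and time point
    median_pred = np.median(predictions, axis=-1)

    # Step 4: correlate actual and predicted signals in non-overlapping windows
    n_chans, n_times = data.shape
    win_len = int(window_secs * sfreq)
    n_windows = n_times // win_len
    corrs = np.zeros((n_windows, n_chans))
    for w in range(n_windows):
        sl = slice(w * win_len, (w + 1) * win_len)
        for ch in range(n_chans):
            corrs[w, ch] = np.corrcoef(data[ch, sl], median_pred[ch, sl])[0, 1]

    # Step 5: proportion of windows whose correlation falls below the threshold
    bad_fraction = np.mean(corrs < corr_thresh, axis=0)

    # Step 6: flag channels with too many bad windows (assumed strict ">")
    return bad_fraction > frac_bad, bad_fraction


# Simulated example: 3 channels, 20 s at 100 Hz, 50 predictions per channel
rng = np.random.default_rng(0)
sfreq, n_times = 100.0, 2000
data = rng.standard_normal((3, n_times))
predictions = data[:, :, None] + 0.1 * rng.standard_normal((3, n_times, 50))
# Make the third channel unpredictable from the others
predictions[2] = rng.standard_normal((n_times, 50))

bad, frac = flag_bad_by_ransac(data, predictions, sfreq)
print(bad, frac)
```

In this simulated case only the third channel, whose predictions are unrelated to its actual signal, ends up flagged.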

With that in mind, here are the areas where PyPREP's defaults deliberately
differ from the original PREP implementation:

Calculation of median estimated signal
--------------------------------------

In MATLAB PREP, the median signal in step 3 is calculated by sorting the
different predictions for each EEG sample/channel from low to high and then
taking the value at the middle index (as calculated by
``int(n_ransac_samples / 2.0)``) for each.

Because this logic only returns the correct result for odd numbers of samples,
PyPREP instead returns the true median signal across predictions unless strict
MATLAB equivalence is requested (``matlab_strict=True``).
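
A small sketch of the difference, following the index arithmetic used in the ``matlab_strict`` branch of ``_run_ransac`` in this PR (``int(_mat_round(n_samples / 2.0) - 1)``). The helper ``mat_round`` below is a hypothetical stand-in for ``pyprep.utils._mat_round``, whose implementation is not shown here; it assumes MATLAB-style rounding where .5 is rounded away from zero.

```python
import numpy as np


def mat_round(x):
    """Round half away from zero (assumed stand-in for MATLAB's round())."""
    return int(np.floor(x + 0.5)) if x >= 0 else int(np.ceil(x - 0.5))


def matlab_style_median(predictions):
    """Sort along the last axis and take the value at the 'middle' index."""
    n = predictions.shape[-1]
    idx = mat_round(n / 2.0) - 1  # convert MATLAB's 1-based index to 0-based
    return np.sort(predictions, axis=-1)[..., idx]


odd = np.array([3.0, 1.0, 2.0])
even = np.array([4.0, 1.0, 3.0, 2.0])

# Odd count: both approaches agree (2.0 and 2.0)
print(matlab_style_median(odd), np.median(odd))
# Even count: middle-index pick gives 2.0, true median gives 2.5
print(matlab_style_median(even), np.median(even))
```

With the default of 50 RANSAC samples (an even number), the two approaches will generally give slightly different median signals.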

Calculation of predicted vs. actual correlations in RANSAC
----------------------------------------------------------

In MATLAB PREP, RANSAC channel predictions are correlated with actual data
in step 4 using a non-standard method: essentially, it uses the standard Pearson
correlation formula but without subtracting the channel means from each channel
before calculating sums of squares, i.e.,::

    SSa = np.sum(a ** 2)
    SSb = np.sum(b ** 2)
    correlation = np.sum(a * b) / (np.sqrt(SSa) * np.sqrt(SSb))

Because EEG data is roughly mean-centered to begin with, this produces similar
values to normal Pearson correlation. However, to avoid making any assumptions
about the signal for any given channel/window, PyPREP defaults to normal
Pearson correlation unless strict MATLAB equivalence is requested.
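
To illustrate why the mean-centering assumption matters, the following sketch compares the two formulas on simulated signals. The function names ``matlab_style_corr`` and ``pearson_corr`` are made up for this example and are not part of PyPREP; the data are random and the offset of 5 is arbitrary.

```python
import numpy as np


def matlab_style_corr(a, b):
    """Pearson formula without subtracting the means first."""
    ssa = np.sum(a ** 2)
    ssb = np.sum(b ** 2)
    return np.sum(a * b) / (np.sqrt(ssa) * np.sqrt(ssb))


def pearson_corr(a, b):
    """Standard Pearson correlation via NumPy."""
    return np.corrcoef(a, b)[0, 1]


rng = np.random.default_rng(0)
a = rng.standard_normal(1000)
b = a + 0.5 * rng.standard_normal(1000)
# Center both signals exactly, so the two formulas should coincide
a -= a.mean()
b -= b.mean()
print(matlab_style_corr(a, b), pearson_corr(a, b))

# Add a constant offset: Pearson is unaffected, the uncentered version is inflated
a_off, b_off = a + 5.0, b + 5.0
print(matlab_style_corr(a_off, b_off), pearson_corr(a_off, b_off))
```

On exactly mean-centered signals the two values match up to floating-point error, while a shared DC offset pushes the uncentered value toward 1 even though the Pearson correlation is unchanged.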
2 changes: 2 additions & 0 deletions docs/whats_new.rst
Expand Up @@ -39,6 +39,7 @@ Changelog
- Changed RANSAC chunking logic to reduce max memory use and prefer equal chunk sizes where possible, by `Austin Hurst`_ (:gh:`44`)
- Changed RANSAC's random channel sampling code to produce the same results as MATLAB PREP for the same random seed, additionally changing the default RANSAC sample size from 25% of all *good* channels (e.g. 15 for a 64-channel dataset with 4 bad channels) to 25% of *all* channels (e.g. 16 for the same dataset), by `Austin Hurst`_ (:gh:`62`)
- Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`)
- Added a new flag ``matlab_strict`` to :class:`pyprep.prep_pipeline.PrepPipeline`, :class:`pyprep.reference.Reference`, :class:`pyprep.find_noisy_channels.NoisyChannels`, and :func:`pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`)

Bug
~~~
Expand All @@ -53,6 +54,7 @@ API
- The following methods have been moved to a new module named :mod:`ransac` and are now private: :meth:`NoisyChannels.ransac_correlations`, :meth:`NoisyChannels.run_ransac <find_noisy_channels.NoisyChannels.run_ransac>`, and :meth:`NoisyChannels.get_ransac_pred <find_noisy_channels.NoisyChannels.get_ransac_pred>` methods, by `Yorguin Mantilla`_ (:gh:`51`)
- The permissible parameters for the following methods were removed and/or reordered: :meth:`NoisyChannels.ransac_correlations <find_noisy_channels.NoisyChannels.ransac_correlations>`, :meth:`NoisyChannels.run_ransac`, and :meth:`NoisyChannels.get_ransac_pred <find_noisy_channels.NoisyChannels.get_ransac_pred>` methods, by `Austin Hurst`_ and `Yorguin Mantilla`_ (:gh:`43`)


.. _changes_0_3_1:

Version 0.3.1
10 changes: 8 additions & 2 deletions pyprep/find_noisy_channels.py
Expand Up @@ -24,21 +24,25 @@ class NoisyChannels:

"""

def __init__(self, raw, do_detrend=True, random_state=None):
def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False):
"""Initialize the class.

Parameters
----------
raw : mne.io.Raw
The MNE raw object.
do_detrend : bool
do_detrend : bool, optional
Whether or not to remove a trend from the data upon initializing the
`NoisyChannels` object. Defaults to True.
random_state : {int, None, np.random.RandomState}, optional
The random seed at which to initialize the class. If random_state
is an int, it will be used as a seed for RandomState.
If None, the seed will be obtained from the operating system
(see RandomState for details). Default is None.
matlab_strict : bool, optional
Whether or not PyPREP should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code.
Defaults to False.

"""
# Make sure that we got an MNE object
Expand All @@ -50,6 +54,7 @@ def __init__(self, raw, do_detrend=True, random_state=None):
self.raw_mne._data = removeTrend(
self.raw_mne.get_data(), sample_rate=self.sample_rate
)
self.matlab_strict = matlab_strict

self.EEGData = self.raw_mne.get_data(picks="eeg")
self.EEGData_beforeFilt = self.EEGData
Expand Down Expand Up @@ -475,6 +480,7 @@ def find_bad_by_ransac(
corr_window_secs,
channel_wise,
self.random_state,
self.matlab_strict,
)
self._extra_info['bad_by_ransac'] = {
'ransac_correlations': ch_correlations,
7 changes: 7 additions & 0 deletions pyprep/prep_pipeline.py
Expand Up @@ -50,6 +50,10 @@ class PrepPipeline:
parameter, but use the "raw" and "prep_params" parameters instead.
If None is passed, the pyprep default settings for filtering are used
instead.
matlab_strict : bool, optional
Whether or not PyPREP should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code.
Defaults to False.

Attributes
----------
Expand Down Expand Up @@ -98,6 +102,7 @@ def __init__(
ransac=True,
random_state=None,
filter_kwargs=None,
matlab_strict=False,
):
"""Initialize PREP class."""
self.raw_eeg = raw.copy()
Expand Down Expand Up @@ -132,6 +137,7 @@ def __init__(
self.ransac = ransac
self.random_state = check_random_state(random_state)
self.filter_kwargs = filter_kwargs
self.matlab_strict = matlab_strict

@property
def raw(self):
Expand Down Expand Up @@ -184,6 +190,7 @@ def fit(self):
self.prep_params,
ransac=self.ransac,
random_state=self.random_state,
matlab_strict=self.matlab_strict
)
reference.perform_reference()
self.raw_eeg = reference.raw
35 changes: 26 additions & 9 deletions pyprep/ransac.py
Expand Up @@ -4,7 +4,9 @@
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import check_random_state

from pyprep.utils import split_list, verify_free_ram, _get_random_subset
from pyprep.utils import (
split_list, verify_free_ram, _get_random_subset, _mat_round, _correlate_arrays
)


def find_bad_by_ransac(
Expand All @@ -20,6 +22,7 @@ def find_bad_by_ransac(
corr_window_secs=5.0,
channel_wise=False,
random_state=None,
matlab_strict=False,
):
"""Detect channels that are not predicted well by other channels.

Expand Down Expand Up @@ -76,6 +79,10 @@ def find_bad_by_ransac(
RANSAC. If random_state is an int, it will be used as a seed for RandomState.
If ``None``, the seed will be obtained from the operating system
(see RandomState for details). Defaults to ``None``.
matlab_strict : bool, optional
Whether or not RANSAC should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code.
Defaults to False.

Returns
-------
Expand Down Expand Up @@ -187,6 +194,7 @@ def find_bad_by_ransac(
n_samples,
n,
w_correlation,
matlab_strict,
)
if chunk == channel_chunks[0]:
# If it gets here, it means it is the optimal
Expand Down Expand Up @@ -233,6 +241,7 @@ def _ransac_correlations(
n_samples,
n,
w_correlation,
matlab_strict,
):
"""Get correlations of channels to their RANSAC-predicted values.

Expand All @@ -259,6 +268,9 @@ def _ransac_correlations(
Number of frames/samples of each window.
w_correlation: int
Number of windows.
matlab_strict : bool
Whether or not RANSAC should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code.

Returns
-------
Expand All @@ -278,6 +290,7 @@ def _ransac_correlations(
good_chn_labs=good_chn_labs,
complete_chn_labs=complete_chn_labs,
data=data,
matlab_strict=matlab_strict,
)

# Correlate ransac prediction and eeg data
Expand All @@ -296,13 +309,7 @@ def _ransac_correlations(
for k in range(w_correlation):
data_portion = data_window[k, :, :]
pred_portion = pred_window[k, :, :]

R = np.corrcoef(data_portion, pred_portion)

# Take only correlations of data with pred
# and use diag to extract correlation of
# data_i with pred_i
R = np.diag(R[0 : len(chans_to_predict), len(chans_to_predict) :])
R = _correlate_arrays(data_portion, pred_portion, matlab_strict)
channel_correlations[k, :] = R

return channel_correlations
Expand All @@ -316,6 +323,7 @@ def _run_ransac(
good_chn_labs,
complete_chn_labs,
data,
matlab_strict,
):
"""Detect noisy channels apart from the ones described previously.

Expand All @@ -339,6 +347,9 @@ def _run_ransac(
labels of the channels in data in the same order
data : np.ndarray
2-D EEG data
matlab_strict : bool
Whether or not RANSAC should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code.

Returns
-------
Expand All @@ -365,7 +376,13 @@ def _run_ransac(
)

# Form median from all predictions
ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True)
if matlab_strict:
# Match MATLAB's rounding logic (.5 always rounded up)
median_idx = int(_mat_round(n_samples / 2.0) - 1)
eeg_predictions.sort(axis=-1)
ransac_eeg = eeg_predictions[:, :, median_idx]
else:
ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True)
return ransac_eeg


27 changes: 22 additions & 5 deletions pyprep/reference.py
Expand Up @@ -38,6 +38,10 @@ class Reference:
an int, it will be used as a seed for RandomState.
If None, the seed will be obtained from the operating system
(see RandomState for details). Default is None.
matlab_strict : bool, optional
Whether or not PyPREP should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code.
Defaults to False.

References
----------
Expand All @@ -47,7 +51,9 @@ class Reference:

"""

def __init__(self, raw, params, ransac=True, random_state=None):
def __init__(
self, raw, params, ransac=True, random_state=None, matlab_strict=False
):
"""Initialize the class."""
self.raw = raw.copy()
self.ch_names = self.raw.ch_names
Expand All @@ -60,6 +66,7 @@ def __init__(self, raw, params, ransac=True, random_state=None):
self.ransac = ransac
self.random_state = check_random_state(random_state)
self._extra_info = {}
self.matlab_strict = matlab_strict

def perform_reference(self):
"""Estimate the true signal mean and interpolate bad channels.
Expand Down Expand Up @@ -94,7 +101,9 @@ def perform_reference(self):

# Phase 2: Find the bad channels and interpolate
self.raw._data = self.EEG * 1e-6
noisy_detector = NoisyChannels(self.raw, random_state=self.random_state)
noisy_detector = NoisyChannels(
self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict
)
noisy_detector.find_all_bads(ransac=self.ransac)

# Record Noisy channels and EEG before interpolation
Expand Down Expand Up @@ -130,7 +139,9 @@ def perform_reference(self):

# Still noisy channels after interpolation
self.interpolated_channels = bad_channels
noisy_detector = NoisyChannels(self.raw, random_state=self.random_state)
noisy_detector = NoisyChannels(
self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict
)
noisy_detector.find_all_bads(ransac=self.ransac)
self.still_noisy_channels = noisy_detector.get_bads()
self.raw.info["bads"] = self.still_noisy_channels
Expand Down Expand Up @@ -169,7 +180,10 @@ def robust_reference(self):

# Determine unusable channels and remove them from the reference channels
noisy_detector = NoisyChannels(
raw, do_detrend=False, random_state=self.random_state
raw,
do_detrend=False,
random_state=self.random_state,
matlab_strict=self.matlab_strict
)
noisy_detector.find_all_bads(ransac=self.ransac)
self.noisy_channels_original = {
Expand Down Expand Up @@ -222,7 +236,10 @@ def robust_reference(self):
while True:
raw_tmp._data = signal_tmp * 1e-6
noisy_detector = NoisyChannels(
raw_tmp, do_detrend=False, random_state=self.random_state
raw_tmp,
do_detrend=False,
random_state=self.random_state,
matlab_strict=self.matlab_strict
)
# Detrend applied at the beginning of the function.
noisy_detector.find_all_bads(ransac=self.ransac)