diff --git a/docs/api.rst b/docs/api.rst index 6cb317f1..9af08c81 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -8,8 +8,8 @@ API Documentation Here we list the Application Programming Interface (API) for pyprep. -The :class:`NoisyChannels` class --------------------------------- +The :class:`~pyprep.NoisyChannels` class +---------------------------------------- .. automodule:: pyprep :no-members: @@ -22,16 +22,24 @@ The :class:`NoisyChannels` class NoisyChannels -The :class:`PrepPipeline` class -------------------------------- +The :class:`~pyprep.Reference` class +------------------------------------ + +.. autosummary:: + :toctree: generated/ + + Reference + +The :class:`~pyprep.PrepPipeline` class +--------------------------------------- .. autosummary:: :toctree: generated/ PrepPipeline -The :mod:`ransac` module -=============================== +The :mod:`~pyprep.ransac` module +================================ .. automodule:: pyprep.ransac :no-members: diff --git a/docs/conf.py b/docs/conf.py index 3378a131..2fba666b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -91,6 +91,7 @@ ("Examples", "auto_examples/index"), ("API", "api"), ("What's new", "whats_new"), + ("Differences from PREP", "matlab_differences"), ("GitHub", "https://github.com/sappelhoff/pyprep", True), ], } @@ -100,7 +101,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "mne": ("https://mne.tools/dev", None), - "numpy": ("https://www.numpy.org/devdocs", None), + "numpy": ("https://numpy.org/devdocs", None), "scipy": ("https://scipy.github.io/devdocs", None), "matplotlib": ("https://matplotlib.org", None), } diff --git a/docs/matlab_differences.rst b/docs/matlab_differences.rst new file mode 100644 index 00000000..17af766f --- /dev/null +++ b/docs/matlab_differences.rst @@ -0,0 +1,105 @@ +:orphan: + +.. 
_matlab-diffs: + +Deliberate Differences from MATLAB PREP +======================================= + +Although PyPREP aims to be a faithful reimplementation of the original MATLAB +version of PREP, there are a few places where PyPREP has deliberately chosen +to use different defaults than the MATLAB PREP. + +To override these differences, you can set the ``matlab_strict`` argument to +:class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, or +:class:`~pyprep.NoisyChannels` as ``True`` to match the original PREP's +internal math. + +.. contents:: Table of Contents + :depth: 3 + + +Differences in RANSAC +--------------------- + +During the "find-bad-by-RANSAC" step of noisy channel detection (see +:func:`~pyprep.ransac.find_bad_by_ransac`), PREP does the following steps to +identify channels that aren't well-predicted by the signals of other channels: + +1) Generates a bunch of random subsets of currently-good channels from the data + (50 samples by default, each containing 25% of the total EEG channels in the + dataset). + +2) Uses the signals and spatial locations of those channels to predict what the + signals will be at the spatial locations of all the other channels, with each + random subset of channels generating a different prediction for each channel + (i.e., 50 predicted signals per channel by default). + +3) For each channel, calculates the median predicted signal from the full set of + predictions. + +4) Splits the full data into small non-overlapping windows (5 seconds by + default) and calculates the correlation between the median predicted signal + and the actual signal for each channel within each window. + +5) Compares the correlations for each channel against a threshold value (0.75 + by default), flags all windows that fall below that threshold as 'bad', and + calculates the proportions of 'bad' windows for each channel. 
+ +6) Flags all channels with an excessively high proportion of 'bad' windows + (minimum 0.4 by default) as being 'bad-by-RANSAC'. + +With that in mind, here are the areas where PyPREP's defaults deliberately +differ from the original PREP implementation: + + +Calculation of median estimated signal +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In MATLAB PREP, the median signal in step 3 is calculated by sorting the +different predictions for each EEG sample/channel from low to high and then +taking the value at the middle index for each. The relevant lines of MATLAB +PREP's ``findNoisyChannels.m`` are reproduced below: + +.. code-block:: matlab + + function rX = calculateRansacWindow(XX, P, n, m, p) + YY = sort(reshape(XX*P, n, m, p),3); + YY = YY(:, :, round(end/2)); + rX = sum(XX.*YY)./(sqrt(sum(XX.^2)).*sqrt(sum(YY.^2))); + +The first line of the function generates the full set of predicted signals for +each RANSAC sample, and then sorts the predicted values for each channel / +timepoint from low to high. The second line calculates the index of the middle +value (``round(end/2)``) and then uses it to take the middle slice of the +sorted array to get the median predicted signal for each channel. + +Because this logic only returns the correct result for odd numbers of samples, +the current function will instead return the true median signal across +predictions unless strict MATLAB equivalence is requested. + + +Correlation of predicted vs. actual signals +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In MATLAB PREP, RANSAC channel predictions are correlated with actual data +in step 4 using a non-standard method: essentially, it uses the standard Pearson +correlation formula but without subtracting the channel means from each channel +before calculating sums of squares. This is done in the last line of the +``calculateRansacWindow`` function reproduced above: + +.. 
code-block:: matlab + + rX = sum(XX.*YY)./(sqrt(sum(XX.^2)).*sqrt(sum(YY.^2))); + +For readability, here's the same formula written in Python code:: + + SSxx = np.sum(xx ** 2) + SSyy = np.sum(yy ** 2) + rX = np.sum(xx * yy) / (np.sqrt(SSxx) * np.sqrt(SSyy)) + +Because the EEG data will have already been filtered to remove slow drifts in +baseline before RANSAC, the signals correlated by this method will already be +roughly mean-centered, and will thus produce similar values to normal Pearson +correlation. However, to avoid making any assumptions about the signal for any +given channel / window, PyPREP defaults to normal Pearson correlation unless +strict MATLAB equivalence is requested. diff --git a/docs/whats_new.rst b/docs/whats_new.rst index e348cf61..b6d80c62 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -7,11 +7,11 @@ People who contributed to this software across releases (in **alphabetical order * `Aamna Lawrence`_ * `Adam Li`_ +* `Austin Hurst`_ * `Christian Oreilly`_ * `Stefan Appelhoff`_ * `Victor Xiang`_ * `Yorguin Mantilla`_ -* `Austin Hurst`_ .. 
_whats_new: @@ -33,12 +33,13 @@ Current Changelog ~~~~~~~~~ -- Created a new module named :mod:`ransac` which contains :func:`find_bad_by_ransac `, a standalone function mirroring the previous ransac method from the :class:`NoisyChannels` class, by `Yorguin Mantilla`_ (:gh:`51`) -- Added two attributes :attr:`PrepPipeline.noisy_channels_before_interpolation ` and :attr:`PrepPipeline.noisy_channels_after_interpolation ` which have the detailed output of each noisy criteria, by `Yorguin Mantilla`_ (:gh:`45`) -- Added two keys to the :attr:`PrepPipeline.noisy_channels_original ` dictionary: ``bad_by_dropout`` and ``bad_by_SNR``, by `Yorguin Mantilla`_ (:gh:`45`) +- Created a new module named :mod:`pyprep.ransac` which contains :func:`find_bad_by_ransac `, a standalone function mirroring the previous ransac method from the :class:`NoisyChannels` class, by `Yorguin Mantilla`_ (:gh:`51`) +- Added two attributes :attr:`PrepPipeline.noisy_channels_before_interpolation ` and :attr:`PrepPipeline.noisy_channels_after_interpolation ` which have the detailed output of each noisy criteria, by `Yorguin Mantilla`_ (:gh:`45`) +- Added two keys to the :attr:`PrepPipeline.noisy_channels_original ` dictionary: ``bad_by_dropout`` and ``bad_by_SNR``, by `Yorguin Mantilla`_ (:gh:`45`) - Changed RANSAC chunking logic to reduce max memory use and prefer equal chunk sizes where possible, by `Austin Hurst`_ (:gh:`44`) - Changed RANSAC's random channel sampling code to produce the same results as MATLAB PREP for the same random seed, additionally changing the default RANSAC sample size from 25% of all *good* channels (e.g. 15 for a 64-channel dataset with 4 bad channels) to 25% of *all* channels (e.g. 
16 for the same dataset), by `Austin Hurst`_ (:gh:`62`) - Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`) +- Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`) Bug ~~~ @@ -49,9 +50,10 @@ Bug API ~~~ -- The permissible parameters for the following methods were removed and/or reordered: :func:`ransac.ransac_correlations`, :func:`ransac.run_ransac`, and :func:`ransac.get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`) -- The following methods have been moved to a new module named :mod:`ransac` and are now private: :meth:`NoisyChannels.ransac_correlations`, :meth:`NoisyChannels.run_ransac `, and :meth:`NoisyChannels.get_ransac_pred ` methods, by `Yorguin Mantilla`_ (:gh:`51`) -- The permissible parameters for the following methods were removed and/or reordered: :meth:`NoisyChannels.ransac_correlations `, :meth:`NoisyChannels.run_ransac`, and :meth:`NoisyChannels.get_ransac_pred ` methods, by `Austin Hurst`_ and `Yorguin Mantilla`_ (:gh:`43`) +- The permissible parameters for the following methods were removed and/or reordered: `ransac._ransac_correlations`, `ransac._run_ransac`, and `ransac._get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`) +- The following methods have been moved to a new module named :mod:`~pyprep.ransac` and are now private: `NoisyChannels.ransac_correlations`, `NoisyChannels.run_ransac`, and `NoisyChannels.get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`) +- The permissible parameters for the following methods were removed and/or reordered: 
`NoisyChannels.ransac_correlations`, `NoisyChannels.run_ransac`, and `NoisyChannels.get_ransac_pred` methods, by `Austin Hurst`_ and `Yorguin Mantilla`_ (:gh:`43`) + .. _changes_0_3_1: @@ -60,22 +62,22 @@ Version 0.3.1 Changelog ~~~~~~~~~ -- It's now possible to pass keyword arguments to the notch filter inside :class:`PrepPipeline `; see the ``filter_kwargs`` parameter by `Yorguin Mantilla`_ (:gh:`40`) +- It's now possible to pass keyword arguments to the notch filter inside :class:`PrepPipeline `; see the ``filter_kwargs`` parameter by `Yorguin Mantilla`_ (:gh:`40`) - The default filter length for the spectrum_fit method will be '10s' to fix memory issues, by `Yorguin Mantilla`_ (:gh:`40`) -- Channel types are now available from a new ``ch_types_all`` attribute, and non-EEG channel names are now available from a new ``ch_names_non_eeg`` attribute from :class:`PrepPipeline `, by `Yorguin Mantilla`_ (:gh:`34`) -- Renaming of ``ch_names`` attribute of :class:`PrepPipeline ` to ``ch_names_all``, by `Yorguin Mantilla`_ (:gh:`34`) -- It's now possible to pass ``'eeg'`` to ``ref_chs`` and ``reref_chs`` keywords to the ``prep_params`` parameter of :class:`PrepPipeline ` to select only eeg channels for referencing, by `Yorguin Mantilla`_ (:gh:`34`) -- :class:`PrepPipeline ` will retain the non eeg channels through the ``raw`` attribute. The eeg-only and non-eeg parts will be in raw_eeg and raw_non_eeg respectively. 
See the ``raw`` attribute, by `Christian Oreilly`_ (:gh:`34`) +- Channel types are now available from a new ``ch_types_all`` attribute, and non-EEG channel names are now available from a new ``ch_names_non_eeg`` attribute from :class:`PrepPipeline `, by `Yorguin Mantilla`_ (:gh:`34`) +- Renaming of ``ch_names`` attribute of :class:`PrepPipeline ` to ``ch_names_all``, by `Yorguin Mantilla`_ (:gh:`34`) +- It's now possible to pass ``'eeg'`` to ``ref_chs`` and ``reref_chs`` keywords to the ``prep_params`` parameter of :class:`PrepPipeline ` to select only eeg channels for referencing, by `Yorguin Mantilla`_ (:gh:`34`) +- :class:`PrepPipeline ` will retain the non eeg channels through the ``raw`` attribute. The eeg-only and non-eeg parts will be in raw_eeg and raw_non_eeg respectively. See the ``raw`` attribute, by `Christian Oreilly`_ (:gh:`34`) - When a ransac call needs more memory than available, pyprep will now automatically switch to a slower but less memory-consuming version of ransac, by `Yorguin Mantilla`_ (:gh:`32`) -- It's now possible to pass an empty list for the ``line_freqs`` param in :class:`PrepPipeline ` to skip the line noise removal, by `Yorguin Mantilla`_ (:gh:`29`) -- The three main classes :class:`PrepPipeline `, :class:`NoisyChannels `, and :class:`Reference ` now have a ``random_state`` parameter to set a seed that gets passed on to all their internal methods and class calls, by `Stefan Appelhoff`_ (:gh:`31`) +- It's now possible to pass an empty list for the ``line_freqs`` param in :class:`PrepPipeline ` to skip the line noise removal, by `Yorguin Mantilla`_ (:gh:`29`) +- The three main classes :class:`~pyprep.PrepPipeline`, :class:`~pyprep.NoisyChannels`, and :class:`pyprep.Reference` now have a ``random_state`` parameter to set a seed that gets passed on to all their internal methods and class calls, by `Stefan Appelhoff`_ (:gh:`31`) Bug ~~~ -- Corrected inconsistency of :mod:`reference` module with the matlab version (:gh:`19`), by `Yorguin 
Mantilla`_ (:gh:`32`) -- Prevented an over detrending in :mod:`reference` module, by `Yorguin Mantilla`_ (:gh:`32`) +- Corrected inconsistency of :class:`~pyprep.Reference` with the matlab version (:gh:`19`), by `Yorguin Mantilla`_ (:gh:`32`) +- Prevented an over detrending in :class:`~pyprep.Reference`, by `Yorguin Mantilla`_ (:gh:`32`) API @@ -91,7 +93,7 @@ Version 0.3.0 Changelog ~~~~~~~~~ -- Include a boolean ``do_detrend`` in :meth:`Reference.robust_reference ` to indicate whether detrend should be done internally or not for the use with :mod:`find_noisy_channels` module, by `Yorguin Mantilla`_ (:gh:`9`) +- Include a boolean ``do_detrend`` in :meth:`~pyprep.Reference.robust_reference` to indicate whether detrend should be done internally or not for the use with :class:`~pyprep.NoisyChannels`, by `Yorguin Mantilla`_ (:gh:`9`) - Robust average referencing + tests, by `Victor Xiang`_ (:gh:`6`) - Removing trend in the EEG data by high pass filtering and local linear regression + tests, by `Aamna Lawrence`_ (:gh:`6`) - Finding noisy channels with comparable output to Matlab + tests-including test for ransac, by `Aamna Lawrence`_ (:gh:`6`) @@ -103,7 +105,7 @@ Changelog Bug ~~~ -- Prevent an undoing of the detrending in :mod:`find_noisy_channels` module, by `Yorguin Mantilla`_ (:gh:`9`) +- Prevent an undoing of the detrending in :class:`~pyprep.NoisyChannels`, by `Yorguin Mantilla`_ (:gh:`9`) API ~~~ diff --git a/examples/run_full_prep.py b/examples/run_full_prep.py index 18cca180..86694e4f 100644 --- a/examples/run_full_prep.py +++ b/examples/run_full_prep.py @@ -5,7 +5,7 @@ In this example we show how to run PREP with ``pyprep``. We also compare -:class:`prep_pipeline.PrepPipeline` with PREP's results in Matlab. +:class:`~pyprep.PrepPipeline` with PREP's results in Matlab. 
We use sample EEG data from Physionet EEG Motor Movement/Imagery Dataset: `https://physionet.org/content/eegmmidb/1.0.0/ `_ diff --git a/pyprep/__init__.py b/pyprep/__init__.py index e01d04c3..e365ca63 100644 --- a/pyprep/__init__.py +++ b/pyprep/__init__.py @@ -1,6 +1,7 @@ """initialize pyprep.""" import pyprep.ransac as ransac # noqa: F401 from pyprep.find_noisy_channels import NoisyChannels # noqa: F401 +from pyprep.reference import Reference # noqa: F401 from pyprep.prep_pipeline import PrepPipeline # noqa: F401 from ._version import get_versions diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 7c0d17b4..d99aa380 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -24,21 +24,25 @@ class NoisyChannels: """ - def __init__(self, raw, do_detrend=True, random_state=None): + def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False): """Initialize the class. Parameters ---------- raw : mne.io.Raw The MNE raw object. - do_detrend : bool + do_detrend : bool, optional Whether or not to remove a trend from the data upon initializing the - `NoisyChannels` object. Defaults to True. + `NoisyChannels` object. Defaults to ``True``. random_state : {int, None, np.random.RandomState}, optional The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. - If None, the seed will be obtained from the operating system - (see RandomState for details). Default is None. + If ``None``, the seed will be obtained from the operating system + (see RandomState for details). Default is ``None``. + matlab_strict : bool, optional + Whether or not PyPREP should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code + (see :ref:`matlab-diffs` for more details). Defaults to ``False``. 
""" # Make sure that we got an MNE object @@ -50,6 +54,7 @@ def __init__(self, raw, do_detrend=True, random_state=None): self.raw_mne._data = removeTrend( self.raw_mne.get_data(), sample_rate=self.sample_rate ) + self.matlab_strict = matlab_strict self.EEGData = self.raw_mne.get_data(picks="eeg") self.EEGData_beforeFilt = self.EEGData @@ -407,7 +412,7 @@ def find_bad_by_ransac( ): """Detect channels that are not predicted well by other channels. - This method is a wrapper for the :func:`pyprep.ransac.find_bad_by_ransac` + This method is a wrapper for the :func:`ransac.find_bad_by_ransac` function. Here, a ransac approach (see [1]_, and a short discussion in [2]_) is @@ -475,6 +480,7 @@ def find_bad_by_ransac( corr_window_secs, channel_wise, self.random_state, + self.matlab_strict, ) self._extra_info['bad_by_ransac'] = { 'ransac_correlations': ch_correlations, diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index fd6c4888..ed0661b5 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -37,8 +37,7 @@ class PrepPipeline: Digital montage of EEG data. ransac : bool, optional Whether or not to use RANSAC for noisy channel detection in addition to - the other methods in :class:`pyprep.find_noisy_channels.NoisyChannels`. - Defaults to True. + the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. random_state : {int, None, np.random.RandomState}, optional The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. @@ -50,6 +49,10 @@ class PrepPipeline: parameter, but use the "raw" and "prep_params" parameters instead. If None is passed, the pyprep default settings for filtering are used instead. + matlab_strict : bool, optional + Whether or not PyPREP should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code + (see :ref:`matlab-diffs` for more details). Defaults to False. 
Attributes ---------- @@ -98,6 +101,7 @@ def __init__( ransac=True, random_state=None, filter_kwargs=None, + matlab_strict=False, ): """Initialize PREP class.""" self.raw_eeg = raw.copy() @@ -132,6 +136,7 @@ def __init__( self.ransac = ransac self.random_state = check_random_state(random_state) self.filter_kwargs = filter_kwargs + self.matlab_strict = matlab_strict @property def raw(self): @@ -184,6 +189,7 @@ def fit(self): self.prep_params, ransac=self.ransac, random_state=self.random_state, + matlab_strict=self.matlab_strict ) reference.perform_reference() self.raw_eeg = reference.raw diff --git a/pyprep/ransac.py b/pyprep/ransac.py index f2fd7c3c..b6e0af5f 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -4,7 +4,9 @@ from mne.channels.interpolation import _make_interpolation_matrix from mne.utils import check_random_state -from pyprep.utils import split_list, verify_free_ram, _get_random_subset +from pyprep.utils import ( + split_list, verify_free_ram, _get_random_subset, _mat_round, _correlate_arrays +) def find_bad_by_ransac( @@ -20,6 +22,7 @@ def find_bad_by_ransac( corr_window_secs=5.0, channel_wise=False, random_state=None, + matlab_strict=False, ): """Detect channels that are not predicted well by other channels. @@ -76,6 +79,10 @@ def find_bad_by_ransac( RANSAC. If random_state is an int, it will be used as a seed for RandomState. If ``None``, the seed will be obtained from the operating system (see RandomState for details). Defaults to ``None``. + matlab_strict : bool, optional + Whether or not RANSAC should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code + (see :ref:`matlab-diffs` for more details). Defaults to ``False``. 
Returns ------- @@ -187,6 +194,7 @@ def find_bad_by_ransac( n_samples, n, w_correlation, + matlab_strict, ) if chunk == channel_chunks[0]: # If it gets here, it means it is the optimal @@ -233,6 +241,7 @@ def _ransac_correlations( n_samples, n, w_correlation, + matlab_strict, ): """Get correlations of channels to their RANSAC-predicted values. @@ -259,6 +268,9 @@ def _ransac_correlations( Number of frames/samples of each window. w_correlation: int Number of windows. + matlab_strict : bool + Whether or not RANSAC should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code. Returns ------- @@ -278,6 +290,7 @@ def _ransac_correlations( good_chn_labs=good_chn_labs, complete_chn_labs=complete_chn_labs, data=data, + matlab_strict=matlab_strict, ) # Correlate ransac prediction and eeg data @@ -296,13 +309,7 @@ def _ransac_correlations( for k in range(w_correlation): data_portion = data_window[k, :, :] pred_portion = pred_window[k, :, :] - - R = np.corrcoef(data_portion, pred_portion) - - # Take only correlations of data with pred - # and use diag to extract correlation of - # data_i with pred_i - R = np.diag(R[0 : len(chans_to_predict), len(chans_to_predict) :]) + R = _correlate_arrays(data_portion, pred_portion, matlab_strict) channel_correlations[k, :] = R return channel_correlations @@ -316,6 +323,7 @@ def _run_ransac( good_chn_labs, complete_chn_labs, data, + matlab_strict, ): """Detect noisy channels apart from the ones described previously. @@ -339,6 +347,9 @@ def _run_ransac( labels of the channels in data in the same order data : np.ndarray 2-D EEG data + matlab_strict : bool + Whether or not RANSAC should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code. 
Returns ------- @@ -365,7 +376,13 @@ def _run_ransac( ) # Form median from all predictions - ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True) + if matlab_strict: + # Match MATLAB's rounding logic (.5 always rounded up) + median_idx = int(_mat_round(n_samples / 2.0) - 1) + eeg_predictions.sort(axis=-1) + ransac_eeg = eeg_predictions[:, :, median_idx] + else: + ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True) return ransac_eeg diff --git a/pyprep/reference.py b/pyprep/reference.py index e2fa834c..92df288a 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -31,13 +31,16 @@ class Reference: - ``reref_chs`` ransac : bool, optional Whether or not to use RANSAC for noisy channel detection in addition to - the other methods in :class:`pyprep.find_noisy_channels.NoisyChannels`. - Defaults to True. + the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. random_state : {int, None, np.random.RandomState}, optional The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. If None, the seed will be obtained from the operating system (see RandomState for details). Default is None. + matlab_strict : bool, optional + Whether or not PyPREP should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code. + Defaults to False. 
References ---------- @@ -47,7 +50,9 @@ class Reference: """ - def __init__(self, raw, params, ransac=True, random_state=None): + def __init__( + self, raw, params, ransac=True, random_state=None, matlab_strict=False + ): """Initialize the class.""" self.raw = raw.copy() self.ch_names = self.raw.ch_names @@ -60,6 +65,7 @@ def __init__(self, raw, params, ransac=True, random_state=None): self.ransac = ransac self.random_state = check_random_state(random_state) self._extra_info = {} + self.matlab_strict = matlab_strict def perform_reference(self): """Estimate the true signal mean and interpolate bad channels. @@ -94,7 +100,9 @@ def perform_reference(self): # Phase 2: Find the bad channels and interpolate self.raw._data = self.EEG * 1e-6 - noisy_detector = NoisyChannels(self.raw, random_state=self.random_state) + noisy_detector = NoisyChannels( + self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict + ) noisy_detector.find_all_bads(ransac=self.ransac) # Record Noisy channels and EEG before interpolation @@ -130,7 +138,9 @@ def perform_reference(self): # Still noisy channels after interpolation self.interpolated_channels = bad_channels - noisy_detector = NoisyChannels(self.raw, random_state=self.random_state) + noisy_detector = NoisyChannels( + self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict + ) noisy_detector.find_all_bads(ransac=self.ransac) self.still_noisy_channels = noisy_detector.get_bads() self.raw.info["bads"] = self.still_noisy_channels @@ -169,7 +179,10 @@ def robust_reference(self): # Determine unusable channels and remove them from the reference channels noisy_detector = NoisyChannels( - raw, do_detrend=False, random_state=self.random_state + raw, + do_detrend=False, + random_state=self.random_state, + matlab_strict=self.matlab_strict ) noisy_detector.find_all_bads(ransac=self.ransac) self.noisy_channels_original = { @@ -222,7 +235,10 @@ def robust_reference(self): while True: raw_tmp._data = signal_tmp * 1e-6 
noisy_detector = NoisyChannels( - raw_tmp, do_detrend=False, random_state=self.random_state + raw_tmp, + do_detrend=False, + random_state=self.random_state, + matlab_strict=self.matlab_strict ) # Detrend applied at the beginning of the function. noisy_detector.find_all_bads(ransac=self.ransac) diff --git a/pyprep/utils.py b/pyprep/utils.py index a70d4976..987d9ed4 100644 --- a/pyprep/utils.py +++ b/pyprep/utils.py @@ -21,6 +21,28 @@ def _intersect(list1, list2): return list(set(list1).intersection(set(list2))) +def _mat_round(x): + """Round a number to the nearest whole number. + + Parameters + ---------- + x : float + The number to round. + + Returns + ------- + rounded : float + The input value, rounded to the nearest whole number. + + Notes + ----- + MATLAB rounds all numbers ending in .5 up to the nearest integer, whereas + Python (and Numpy) rounds them to the nearest even number. This function + mimics MATLAB's behaviour. + """ + return np.ceil(x) if x % 1 >= 0.5 else np.floor(x) + + def _mat_quantile(arr, q, axis=None): """Calculate the numeric value at quantile (`q`) for a given distribution. @@ -121,6 +143,56 @@ def _get_random_subset(x, size, rand_state): return sample +def _correlate_arrays(a, b, matlab_strict=False): + """Calculate correlations between two equally-sized 2-D arrays of EEG data. + + Both input arrays must be in the shape (channels, samples). + + Parameters + ---------- + a : np.ndarray + A 2-D array to correlate with `a`. + b : np.ndarray + A 2-D array to correlate with `b`. + matlab_strict : bool, optional + Whether or not correlations should be calculated identically to MATLAB + PREP (i.e., without mean subtraction) instead of by traditional Pearson + product-moment correlation (see Notes for details). Defaults to + ``False`` (Pearson correlation). + + Returns + ------- + correlations : np.ndarray + A one-dimensional array containing the correlations of the two input arrays + along the second axis. 
+ + Notes + ----- + In MATLAB PREP, RANSAC channel predictions are correlated with actual data + using a non-standard method: essentialy, it uses the standard Pearson + correlation formula but without subtracting the channel means from each channel + before calculating sums of squares, i.e.,:: + + SSa = np.sum(a ** 2) + SSb = np.sum(b ** 2) + correlation = np.sum(a * b) / (np.sqrt(SSa) * np.sqrt(SSb)) + + Because EEG data is roughly mean-centered to begin with, this produces similar + values to normal Pearson correlation. However, to avoid making any assumptions + about the signal for any given channel/window, PyPREP defaults to normal + Pearson correlation unless strict MATLAB equivalence is requested. + + """ + if matlab_strict: + SSa = np.sum(a ** 2, axis=1) + SSb = np.sum(b ** 2, axis=1) + SSab = np.sum(a * b, axis=1) + return SSab / (np.sqrt(SSa) * np.sqrt(SSb)) + else: + n_chan = a.shape[0] + return np.diag(np.corrcoef(a, b)[:n_chan, n_chan:]) + + def filter_design(N_order, amp, freq): """Create FIR low-pass filter for EEG data using frequency sampling method. diff --git a/tests/test_find_noisy_channels.py b/tests/test_find_noisy_channels.py index b2c26d45..4e79558a 100644 --- a/tests/test_find_noisy_channels.py +++ b/tests/test_find_noisy_channels.py @@ -121,6 +121,16 @@ def test_findnoisychannels(raw, montage): bads = nd.bad_by_ransac assert bads == raw_tmp.ch_names[0:6] + # Test for finding bad channels by matlab_strict RANSAC + raw_tmp = raw.copy() + # Ransac identifies channels that go bad together and are highly correlated. 
+ # Inserting highly correlated signal in channels 0 through 3 at 30 Hz + raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 + nd = NoisyChannels(raw_tmp, random_state=rng, matlab_strict=True) + nd.find_bad_by_ransac() + bads = nd.bad_by_ransac + assert bads == raw_tmp.ch_names[0:6] + # Test for finding bad channels by channel-wise RANSAC raw_tmp = raw.copy() # Ransac identifies channels that go bad together and are highly correlated. diff --git a/tests/test_utils.py b/tests/test_utils.py index 80fc3854..a7711b88 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,20 @@ """Test various helper functions.""" import numpy as np -from pyprep.utils import _mat_quantile, _mat_iqr, _get_random_subset +from pyprep.utils import ( + _mat_round, _mat_quantile, _mat_iqr, _get_random_subset, _correlate_arrays +) + + +def test_mat_round(): + """Test the MATLAB-compatible rounding function.""" + # Test normal rounding behaviour + assert _mat_round(1.5) == 2 + assert _mat_round(0.4) == 0 + assert _mat_round(0.6) == 1 + + # Test MATLAB-specific rounding behaviour + assert _mat_round(0.5) == 1 def test_mat_quantile_iqr(): @@ -47,3 +60,35 @@ def test_get_random_subset(): expected_picks = [6, 47, 55, 31, 29, 44, 36, 15] actual_picks = _get_random_subset(chans, size=8, rand_state=rng) assert all(np.equal(expected_picks, actual_picks)) + + +def test_correlate_arrays(): + """Test MATLAB PREP-compatible array correlation function. + + MATLAB code used to generate the comparison results: + + .. 
code-block:: matlab + + % Generate test data + rng(435656); + a = rand(100, 3) - 0.5; + b = rand(100, 3) - 0.5; + + % Calculate correlations + correlations = sum(a.*b)./(sqrt(sum(a.^2)).*sqrt(sum(b.^2))); + + """ + # Generate test data + np.random.seed(435656) + a = np.random.rand(3, 100) - 0.5 + b = np.random.rand(3, 100) - 0.5 + + # Test regular Pearson correlation + corr_expected = np.asarray([-0.0898, 0.0340, -0.1068]) + corr_actual = _correlate_arrays(a, b) + assert all(np.isclose(corr_expected, corr_actual, atol=0.001)) + + # Test correlation equivalence with MATLAB PREP + corr_expected = np.asarray([-0.0898, 0.0327, -0.1140]) + corr_actual = _correlate_arrays(a, b, matlab_strict=True) + assert all(np.isclose(corr_expected, corr_actual, atol=0.001))