diff --git a/docs/matlab_differences.rst b/docs/matlab_differences.rst index 43228640..199c674d 100644 --- a/docs/matlab_differences.rst +++ b/docs/matlab_differences.rst @@ -9,10 +9,10 @@ Although PyPREP aims to be a faithful reimplementation of the original MATLAB version of PREP, there are a few places where PyPREP has deliberately chosen to use different defaults than the MATLAB PREP. -To override these differerences, you can set the ``matlab_strict`` argument to -:class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, or -:class:`~pyprep.NoisyChannels` as ``True`` to match the original PREP's -internal math. +To override these differerences, you can set the ``matlab_strict`` parameter +for :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, or +:class:`~pyprep.NoisyChannels` to ``True`` in order to match the original +PREP's internal math. .. contents:: Table of Contents :depth: 3 @@ -39,8 +39,8 @@ Because the practical differences are small and MNE's filtering is fast and well-tested, PyPREP defaults to using :func:`mne.filter.filter_data` for high-pass trend removal. However, for exact numerical compatibility, PyPREP has a basic re-implementation of EEGLAB's ``pop_eegfiltnew`` in Python that -produces identical results to MATLAB PREP's ``removeTrend`` when -``matlab_strict`` is set to ``True``. +produces identical results to MATLAB PREP's ``removeTrend`` when the +``matlab_strict`` parameter is set to ``True``. Differences in RANSAC @@ -93,8 +93,8 @@ approach has the benefit of better randomness, but may also lead to more variability in PREP results between different seed values. More testing is required to determine which approach produces better results. -Note that to match MATLAB PREP exactly when ``matlab_strict`` is ``True``, the -random seed ``435656`` must be used. +Note that to match MATLAB PREP exactly when the ``matlab_strict`` parameter is +set to ``True``, the random seed ``435656`` must be used. Calculation of median estimated signal @@ -188,3 +188,35 @@ of flat signal) are detected on each iteration of the reference loop, but are currently not factored into the full set of "bad" channels to be interpolated. By contrast, PyPREP will detect and interpolate any bad-by-dropout channels detected during robust referencing. + + +Bad channel interpolation +^^^^^^^^^^^^^^^^^^^^^^^^^ + +MATLAB PREP uses EEGLAB's internal ``eeg_interp`` method of spherical spline +interpolation for interpolating identified bad channels during robust reference +estimation and (if enabled) immediately after the robust reference signal is +applied in order to remove any remaining detected bad channels once referencing +is complete. + +However, ``eeg_interp``'s method of spherical interpolations differs quite a bit +numerically from MNE's implementation as well as the interpolation method used +by MATLAB PREP for RANSAC predictions, both of which are numerically identical +and based directly on the formulas in Perrin et al. (1989) [1]_. ``eeg_interp`` +seems to use a modified variation of the Perrin et al. method, but diverges in +a number of ways that are not clearly documented or cited in the code. + +To keep with the more established method of spherical interpolation and stay +consistent with the interpolation code used in RANSAC, PyPREP defaults to using +MNE's :meth:`~mne.io.Raw.interpolate_bads` method for interpolation during and +following robust referencing. However, for full numeric equivalence with +MATLAB PREP, PyPREP will use a Python reimplementation of ``eeg_interp`` instead +when the ``matlab_strict`` parameter is set to ``True``. + + +References +---------- + +.. [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989). + Spherical splines for scalp potential and current density mapping. + Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7. diff --git a/docs/whats_new.rst b/docs/whats_new.rst index e1a78421..85b24db8 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -51,6 +51,7 @@ Changelog - Added a new argument `max_iterations` for :meth:`~pyprep.Reference.perform_reference` and :meth:`~pyprep.Reference.robust_reference`, allowing the maximum number of referencing iterations to be user-configurable, by `Austin Hurst`_ (:gh:`93`) - Changed :meth:`~pyprep.Reference.robust_reference` to ignore bad-by-dropout channels during referencing if ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`) - Changed :meth:`~pyprep.Reference.robust_reference` to allow initial bad-by-SNR channels to be used for rereferencing interpolation if no longer bad following initial average reference, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`) +- Added a ``matlab_strict`` method for bad channel interpolation, allowing for full numeric equivalence with MATLAB PREP's robust referencing, by `Austin Hurst`_ (:gh:`96`) .. _matprep_artifacts: https://github.com/a-hurst/matprep_artifacts diff --git a/pyprep/reference.py b/pyprep/reference.py index 646dba0e..ad6cbdd4 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -6,7 +6,7 @@ from pyprep.find_noisy_channels import NoisyChannels from pyprep.removeTrend import removeTrend -from pyprep.utils import _set_diff, _union +from pyprep.utils import _eeglab_interpolate_bads, _set_diff, _union logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -118,7 +118,10 @@ def perform_reference(self, max_iterations=4): # more than what we later actually account for (in interpolated channels). dummy = self.raw.copy() dummy.info["bads"] = self.noisy_channels["bad_all"] - dummy.interpolate_bads() + if self.matlab_strict: + _eeglab_interpolate_bads(dummy) + else: + dummy.interpolate_bads() self.reference_signal = ( np.nanmean(dummy.get_data(picks=self.reference_channels), axis=0) * 1e6 ) @@ -145,7 +148,10 @@ def perform_reference(self, max_iterations=4): bad_channels = _union(self.bad_before_interpolation, self.unusable_channels) self.raw.info["bads"] = bad_channels - self.raw.interpolate_bads() + if self.matlab_strict: + _eeglab_interpolate_bads(self.raw) + else: + self.raw.interpolate_bads() reference_correct = ( np.nanmean(self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6 ) @@ -293,7 +299,10 @@ def robust_reference(self, max_iterations=4): if len(bad_chans) > 0: raw_tmp._data = signal * 1e-6 raw_tmp.info["bads"] = list(bad_chans) - raw_tmp.interpolate_bads() + if self.matlab_strict: + _eeglab_interpolate_bads(raw_tmp) + else: + raw_tmp.interpolate_bads() signal_tmp = raw_tmp.get_data() * 1e6 else: signal_tmp = signal diff --git a/pyprep/utils.py b/pyprep/utils.py index 2144226b..a8460a1a 100644 --- a/pyprep/utils.py +++ b/pyprep/utils.py @@ -2,9 +2,13 @@ import math from cmath import sqrt +import mne import numpy as np import scipy.interpolate +from mne.surface import _normalize_vectors +from numpy.polynomial.legendre import legval from psutil import virtual_memory +from scipy import linalg from scipy.signal import firwin, lfilter, lfilter_zi @@ -235,6 +239,143 @@ def _eeglab_fir_filter(data, filt): return out +def _eeglab_calc_g(pos_from, pos_to, stiffness=4, num_lterms=7): + """Calculate spherical spline g function between points on a sphere. + + Parameters + ---------- + pos_from : np.ndarray of float, shape(n_good_sensors, 3) + The electrode positions to interpolate from. + pos_to : np.ndarray of float, shape(n_bad_sensors, 3) + The electrode positions to interpolate. + stiffness : float + Stiffness of the spline. + num_lterms : int + Number of Legendre terms to evaluate. + + Returns + ------- + G : np.ndarray of float, shape(n_channels, n_channels) + The G matrix. + + Notes + ----- + Produces identical output to the private ``computeg`` function in EEGLAB's + ``eeg_interp.m``. + + """ + # https://github.com/sccn/eeglab/blob/167dfc8/functions/popfunc/eeg_interp.m#L347 + + n_to = pos_to.shape[0] + n_from = pos_from.shape[0] + + # Calculate the Euclidian distances between the 'to' and 'from' electrodes + dxyz = [] + for i in range(0, 3): + d1 = np.repeat(pos_to[:, i], n_from).reshape((n_to, n_from)) + d2 = np.repeat(pos_from[:, i], n_to).reshape((n_from, n_to)).T + dxyz.append((d1 - d2) ** 2) + elec_dists = np.sqrt(sum(dxyz)) + + # Subtract all the Euclidian electrode distances from 1 (why?) + EI = np.ones([n_to, n_from]) - elec_dists + + # Calculate Legendre coefficients for the given degree and stiffness + factors = [0] + for n in range(1, num_lterms + 1): + f = (2 * n + 1) / (n ** stiffness * (n + 1) ** stiffness * 4 * np.pi) + factors.append(f) + + return legval(EI, factors) + + +def _eeglab_interpolate(data, pos_from, pos_to): + """Interpolate bad channels using EEGLAB's custom method. + + Parameters + ---------- + data : np.ndarray + A 2-D array containing signals from currently-good EEG channels with + which to interpolate signals for bad channels. + pos_from : np.ndarray of float, shape(n_good_sensors, 3) + The electrode positions to interpolate from. + pos_to : np.ndarray of float, shape(n_bad_sensors, 3) + The electrode positions to interpolate. + + Returns + ------- + interpolated : np.ndarray + The interpolated signals for all bad channels. + + Notes + ----- + Produces identical output to the private ``spheric_spline`` function in + EEGLAB's ``eeg_interp.m`` (with minor rounding errors). + + """ + # https://github.com/sccn/eeglab/blob/167dfc8/functions/popfunc/eeg_interp.m#L314 + + # Calculate G for distances between good electrodes + between goods & bads + G_from = _eeglab_calc_g(pos_from, pos_from) + G_to_from = _eeglab_calc_g(pos_from, pos_to) + + # Get average reference signal for all good channels and subtract from data + avg_ref = np.mean(data, axis=0) + data_tmp = data - avg_ref + + # Calculate interpolation matrix from electrode locations + pad_ones = np.ones((1, pos_from.shape[0])) + C_inv = linalg.pinv(np.vstack([G_from, pad_ones])) + interp_mat = np.matmul(G_to_from, C_inv[:, :-1]) + + # Interpolate bad channels and add average good reference to them + interpolated = np.matmul(interp_mat, data_tmp) + avg_ref + + return interpolated + + +def _eeglab_interpolate_bads(raw): + """Interpolate bad channels using EEGLAB's custom method. + + This method modifies the provided Raw object in place. + + Parameters + ---------- + raw : mne.io.Raw + An MNE Raw object for which channels marked as "bad" should be + interpolated. + + Notes + ----- + Produces identical results as EEGLAB's ``eeg_interp`` function when using + the default spheric spline method (with minor rounding errors). This method + appears to be loosely based on the same general Perrin et al. (1989) method + as MNE's interpolation, but there are several quirks with the implementation + that cause it to produce fairly different numbers. + + """ + # Get the indices of good and bad EEG channels + eeg_chans = mne.pick_types(raw.info, eeg=True, exclude=[]) + good_idx = mne.pick_types(raw.info, eeg=True, exclude="bads") + bad_idx = sorted(_set_diff(eeg_chans, good_idx)) + + # Get the spatial coordinates of the good and bad electrodes + elec_pos = raw._get_channel_positions(picks=eeg_chans) + pos_good = elec_pos[good_idx, :].copy() + pos_bad = elec_pos[bad_idx, :].copy() + _normalize_vectors(pos_good) + _normalize_vectors(pos_bad) + + # Interpolate bad channels + interp = _eeglab_interpolate(raw._data[good_idx, :], pos_good, pos_bad) + raw._data[bad_idx, :] = interp + + # Clear all bad EEG channels + eeg_bad_names = [raw.info["ch_names"][i] for i in bad_idx] + bads_non_eeg = _set_diff(raw.info["bads"], eeg_bad_names) + raw.info["bads"] = bads_non_eeg + + def _get_random_subset(x, size, rand_state): """Get a random subset of items from a list or array, without replacement. diff --git a/tests/test_matprep_compare.py b/tests/test_matprep_compare.py index 9b8b722d..997c7bf0 100644 --- a/tests/test_matprep_compare.py +++ b/tests/test_matprep_compare.py @@ -7,6 +7,7 @@ import scipy from pyprep.find_noisy_channels import NoisyChannels +from pyprep.reference import Reference from pyprep.removeTrend import removeTrend # Define some fixtures for things that will be used across multiple tests @@ -46,24 +47,23 @@ def matprep_artifacts(tmpdir_factory): @pytest.fixture(scope="session") -def matprep_noisy(matprep_artifacts): - """Import and preprocess artifact containing MATLAB PREP runtime info. - - This fixture only parses and retains data from the first pass of noisy - channel detection during re-referencing, since it's easiest to compare with - PyPREP. It also adds a new key to the imported struct, 'bads', which - contains the names of the channels flagged as bad by each detection - method (as opposed to just channel indices). +def matprep_info(matprep_artifacts): + """Get the runtime info data from MATLAB PREP for comparison with PyPREP. + This fixture helps convert the MATLAB PREP runtime info into a format that's + easier to compare with PyPREP, replacing channel indices with channel names + and renaming variables for consistency. """ # Read in and parse noisy channel info artifact from MATLAB PREP info_path = matprep_artifacts["matprep_info"] matprep_info = scipy.io.loadmat(info_path, simplify_cells=True)["prep_info"] - matprep_noisy_all = matprep_info["reference"] - matprep_noisy = matprep_noisy_all["noisyStatisticsOriginal"] - # Gather bad channel names from MatPREP, converting numbers to labels - ch_names = matprep_info["originalChannelLabels"] + # Extract all noisy info from MatPREP, replacing channel numbers with labels + noisy_types = { + "original": "noisyStatisticsOriginal", + "post_ref": "noisyStatisticsBeforeInterpolation", + "post_interp": "noisyStatistics", + } bad_types = { "badChannelsFromNaNs": "by_nan", "badChannelsFromNoData": "by_flat", @@ -75,14 +75,37 @@ def matprep_noisy(matprep_artifacts): "badChannelsFromRansac": "by_ransac", "all": "all", } - matprep_bads = {} - for bad_type, name in bad_types.items(): - bads_idx = matprep_noisy["noisyChannels"][bad_type] - bads_idx = [bads_idx] if isinstance(bads_idx, int) else bads_idx - matprep_bads[name] = [ch_names[i - 1] for i in bads_idx] - matprep_noisy["bads"] = matprep_bads + matprep_noisy_all = {} + ch_names = matprep_info["originalChannelLabels"] + for type_name, noisy_type in noisy_types.items(): + matprep_noisy = matprep_info["reference"][noisy_type] + matprep_bads = {} + for bad_type, name in bad_types.items(): + bads_idx = matprep_noisy["noisyChannels"][bad_type] + bads_idx = [bads_idx] if isinstance(bads_idx, int) else bads_idx + matprep_bads[name] = [ch_names[i - 1] for i in bads_idx] + matprep_noisy["bads"] = matprep_bads + matprep_noisy_all[type_name] = matprep_noisy + + out = { + "ch_names": ch_names, + "noisy": matprep_noisy_all, + "ref_signal": matprep_info["reference"]["referenceSignal"], + "cleanline": matprep_info["lineNoise"], + } + return out + - return matprep_noisy +@pytest.fixture(scope="session") +def matprep_noisy(matprep_info): + """Get MATLAB PREP runtime info regarding initial noisy channel detection. + + This fixture only provides data from the first pass of noisy channel + detection during re-referencing, since it's easiest to compare with + PyPREP. + + """ + return matprep_info["noisy"]["original"] @pytest.fixture(scope="session") @@ -91,7 +114,7 @@ def pyprep_noisy(matprep_artifacts): This fixture uses an artifact from MATLAB PREP of the CleanLined and detrended EEG signal right before MATLAB PREP runs its first iteration of - NoisyChannels during re-referncing. As such, any differences in test results + NoisyChannels during re-referencing. As such, any differences in test results will be due to actual differences in the noisy channel detection code rather than differences at an earlier stage of the pipeline. @@ -110,6 +133,41 @@ def pyprep_noisy(matprep_artifacts): return pyprep_noisy +@pytest.fixture(scope="session") +def matprep_reference(matprep_artifacts, matprep_info): + """Get robust re-referenced signal from MATLAB PREP.""" + # Import post-reference MATLAB PREP data + postref_path = matprep_artifacts["5_matprep_post_reference"] + matprep_postref = mne.io.read_raw_eeglab(postref_path, preload=True) + return matprep_postref + + +@pytest.fixture(scope="session") +def pyprep_reference(matprep_artifacts): + """Get the robust re-referenced signal for comparison with MATLAB PREP. + + This fixture uses an artifact from MATLAB PREP of the CleanLined EEG signal + right before MATLAB PREP calls ``performReference``. As such, the results + of these tests will not be affected by any differences in the CleanLine + implementations of MATLAB PREP and PyPREP. + + """ + # Import post-CleanLine MATLAB PREP data + setfile_path = matprep_artifacts["3_matprep_cleanline"] + matprep_set = mne.io.read_raw_eeglab(setfile_path, preload=True) + ch_names = matprep_set.info["ch_names"] + + # Run robust referencing on MATLAB data and extract internal noisy info + matprep_seed = 435656 + params = {"ref_chs": ch_names, "reref_chs": ch_names} + pyprep_reref = Reference( + matprep_set, params, random_state=matprep_seed, matlab_strict=True + ) + pyprep_reref.perform_reference() + + return pyprep_reref + + # Define MATLAB comparison tests for each main component of PyPREP @@ -291,3 +349,48 @@ def test_all_bads(self, pyprep_noisy, matprep_noisy): pyprep_bads_all = sorted(pyprep_noisy.get_bads()) matprep_bads_all = sorted(matprep_noisy["bads"]["all"]) assert pyprep_bads_all == matprep_bads_all + + +class TestCompareRobustReference(object): + """Compare the results of Reference to the equivalent MatPREP code. + + These comparisons use input data that's already had adaptive line noise + removal done to the signal, so any differences in results will be due to + differences in the robust referencing code itself. + + """ + + # TODO: once final interpolation is separated out in PyPREP, add test just + # for interpolation code + + def test_pre_interp_bads(self, pyprep_reference, matprep_info): + """Compare pre-interpolation bads between PyPREP and MatPREP.""" + matprep_bads = matprep_info["noisy"]["post_ref"]["bads"]["all"] + pyprep_bads = pyprep_reference.bad_before_interpolation + assert sorted(pyprep_bads) == sorted(matprep_bads) + + def test_remaining_bads(self, pyprep_reference, matprep_info): + """Compare post-interpolation bads between PyPREP and MatPREP.""" + matprep_bads = matprep_info["noisy"]["post_interp"]["bads"]["all"] + pyprep_bads = pyprep_reference.still_noisy_channels + assert sorted(pyprep_bads) == sorted(matprep_bads) + + def test_reference_signal(self, pyprep_reference, matprep_info): + """Compare the final reference signal between PyPREP and MatPREP.""" + TOL = 1e-4 # NOTE: Some diffs > 1e-5, maybe rounding error? + pyprep_ref = pyprep_reference.reference_signal_new + assert np.allclose(pyprep_ref, matprep_info["ref_signal"], atol=TOL) + + def test_full_signal(self, pyprep_reference, matprep_reference): + """Compare the full post-reference signal between PyPREP and MatPREP.""" + win_size = 500 # window of samples to check + + # Compare signals at start of recording + pyprep_start = pyprep_reference.raw._data[:, win_size] + matprep_start = matprep_reference._data[:, win_size] + assert np.allclose(pyprep_start, matprep_start) + + # Compare signals at end of recording + pyprep_end = pyprep_reference.raw._data[:, -win_size:] + matprep_end = matprep_reference._data[:, -win_size:] + assert np.allclose(pyprep_end, matprep_end)