diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 41120e71..0e790108 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -375,6 +375,7 @@ def find_bad_by_ransac( fraction_bad=0.4, corr_window_secs=5.0, channel_wise=False, + window_wise=False ): """Detect channels that are not predicted well by other channels. @@ -439,5 +440,6 @@ def find_bad_by_ransac( fraction_bad, corr_window_secs, channel_wise, + window_wise, self.random_state, ) diff --git a/pyprep/ransac.py b/pyprep/ransac.py index ef25068e..dd6f5391 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -4,7 +4,10 @@ from mne.channels.interpolation import _make_interpolation_matrix from mne.utils import check_random_state -from pyprep.utils import split_list, verify_free_ram, _get_random_subset +from pyprep.utils import ( + split_list, verify_free_ram, _get_random_subset, _mat_round, + _correlate_arrays +) def find_bad_by_ransac( @@ -20,6 +23,7 @@ def find_bad_by_ransac( fraction_bad=0.4, corr_window_secs=5.0, channel_wise=False, + window_wise=False, random_state=None, ): """Detect channels that are not predicted well by other channels. @@ -127,7 +131,7 @@ def find_bad_by_ransac( random_ch_picks.append(picks) # Correlation windows setup - correlation_frames = corr_window_secs * sample_rate + correlation_frames = int(corr_window_secs * sample_rate) correlation_window = np.arange(correlation_frames) n = correlation_window.shape[0] correlation_offsets = np.arange( @@ -159,7 +163,7 @@ def find_bad_by_ransac( if channel_wise: chunk_size = 1 - while mem_error: + while mem_error and not window_wise: try: channel_chunks = split_list(job, chunk_size) total_chunks = len(channel_chunks) @@ -198,9 +202,20 @@ def find_bad_by_ransac( "ested samples." ) - # Thresholding - thresholded_correlations = channel_correlations < corr_thresh - frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0) + if window_wise: + # Get correlations between actual vs predicted signals for each RANSAC window + interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good) + channel_correlations = _ransac_correlations_alt( + data[good_idx, :], interp_mats, correlation_frames, w_correlation + ) + # Calculate fractions of bad RANSAC windows for each channel + thresholded_correlations = channel_correlations < corr_thresh + frac_bad_corr_windows = np.zeros(n_chans) + frac_bad_corr_windows[good_idx] = np.mean(thresholded_correlations, axis=0) + else: + # Calculate fractions of bad RANSAC windows for each channel + thresholded_correlations = channel_correlations < corr_thresh + frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0) # find the corresponding channel names and return bad_ransac_channels_idx = np.argwhere(frac_bad_corr_windows > fraction_bad) @@ -396,3 +411,126 @@ def _get_ransac_pred( interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos) ransac_pred = np.matmul(interpol_mat, data[reconstr_picks, :]) return ransac_pred + + +def _ransac_correlations_alt(good_data, interpolation_mats, win_size, win_count): + """Calculate correlations of channels with their RANSAC-predicted values. + + Parameters + ---------- + good_data : np.ndarray + A 2-D array containing the EEG signals from currently-good channels. + interpolation_mats : list of np.ndarray + A list of interpolation matrices, one for each RANSAC sample of channels. + win_size : int + Number of frames/samples of EEG data in each RANSAC correlation window. + win_count: int + Number of RANSAC correlation windows. + + Returns + ------- + channel_correlations : np.ndarray + Correlations of the given channels to their predicted values within each + RANSAC window. + """ + chn_count = good_data.shape[0] + channel_correlations = np.ones((win_count, chn_count)) + + for window in range(win_count): + + # Get the current window of EEG data + start = window * win_size + end = (window + 1) * win_size + actual = good_data[:, start:end] + + # Get the median RANSAC-predicted signal for each channel + predicted = _predict_median_signals(actual, interpolation_mats, True) + + # Calculate the actual vs predicted signal correlation for each channel + channel_correlations[window, :] = _correlate_arrays(actual, predicted) + + return channel_correlations + + +def _make_interpolation_matrices(random_ch_picks, chn_pos_good): + """Create an interpolation matrix for each RANSAC sample of channels. + + This function takes the spatial coordinates of random subsets of currently-good + channels and uses them to predict what the signal will be at the spatial + coordinates of all other currently-good channels. The results of this process are + returned as matrices that can be multiplied with EEG data to generate predicted + signals. + + Parameters + ---------- + random_ch_picks : list of list of int + A list containing multiple random subsets of currently-good channels. + chn_pos_good : np.ndarray + 3-D spatial coordinates of all currently-good channels. + + Returns + ------- + interpolation_mats : list of np.ndarray + A list of interpolation matrices, one for each random subset of channels. + Each matrix has the shape `[num_good_channels, num_good_channels]`, with the + number of good channels being inferred from the size of `ch_pos_good`. + + Notes + ----- + This function currently makes use of a private MNE function, + ``mne.channels.interpolation._make_interpolation_matrix``, to generate matrices. + """ + n_chans_good = chn_pos_good.shape[0] + + interpolation_mats = [] + for sample in random_ch_picks: + mat = np.zeros((n_chans_good, n_chans_good)) + subset_pos = chn_pos_good[sample, :] + mat[:, sample] = _make_interpolation_matrix(subset_pos, chn_pos_good) + interpolation_mats.append(mat) + + return interpolation_mats + + +def _predict_median_signals(window, interpolation_mats, matlab_strict=False): + """Calculate the median RANSAC-predicted signal for a given window of data. + + Parameters + ---------- + window : np.ndarray + A 2-D window of EEG data with the shape `[channels, samples]`. + interpolation_mats : list of np.ndarray + A set of channel interpolation matrices, one for each RANSAC sample of + channels. + matlab_strict : bool, optional + Whether MATLAB PREP's internal logic should be strictly followed (see Notes). + Defaults to False. + + Returns + ------- + predicted : np.ndarray + The median RANSAC-predicted EEG signal for the given window of data. + + Notes + ----- + In MATLAB PREP, the median signal is calculated by sorting the different + predictions for each EEG sample/channel from low to high and then taking the value + at the middle index (as calculated by `int(n_ransac_samples / 2.0)`) for each. + Because this logic only returns the correct result for odd numbers of samples, the + current function will instead return the true median signal across predictions + unless strict MATLAB equivalence is requested. + """ + ransac_samples = len(interpolation_mats) + merged_mats = np.concatenate(interpolation_mats, axis=0) + + predictions_per_sample = np.reshape( + np.matmul(merged_mats, window), + (ransac_samples, window.shape[0], window.shape[1]) + ) + + if matlab_strict: + # Match MATLAB's rounding logic (.5 always rounded up) + median_idx = int(_mat_round(ransac_samples / 2.0) - 1) + return np.sort(predictions_per_sample, axis=0)[median_idx, :, :] + else: + return np.median(predictions_per_sample, axis=0) diff --git a/tests/test_find_noisy_channels.py b/tests/test_find_noisy_channels.py index b2c26d45..5ca5190e 100644 --- a/tests/test_find_noisy_channels.py +++ b/tests/test_find_noisy_channels.py @@ -121,6 +121,16 @@ def test_findnoisychannels(raw, montage): bads = nd.bad_by_ransac assert bads == raw_tmp.ch_names[0:6] + # Test for finding bad channels by window-wise RANSAC + raw_tmp = raw.copy() + # Ransac identifies channels that go bad together and are highly correlated. + # Inserting highly correlated signal in channels 0 through 3 at 30 Hz + raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 + nd = NoisyChannels(raw_tmp, random_state=rng) + nd.find_bad_by_ransac(window_wise=True) + bads = nd.bad_by_ransac + assert bads == raw_tmp.ch_names[0:6] + # Test for finding bad channels by channel-wise RANSAC raw_tmp = raw.copy() # Ransac identifies channels that go bad together and are highly correlated.