From 9f5c7843c873347086e961fb3499039c408e8d5a Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Thu, 29 Apr 2021 14:23:14 -0300 Subject: [PATCH 01/10] Add pre-generation of interpolation matrices --- pyprep/ransac.py | 138 +++++++++++++++++++++++++---------------------- 1 file changed, 74 insertions(+), 64 deletions(-) diff --git a/pyprep/ransac.py b/pyprep/ransac.py index 5bc284f3..b0cc3b9e 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -142,6 +142,9 @@ def find_bad_by_ransac( picks = _get_random_subset(good_chans, n_pred_chns, rng) random_ch_picks.append(picks) + # Generate interpolation matrix for each RANSAC sample + interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good) + # Correlation windows setup correlation_frames = corr_window_secs * sample_rate correlation_window = np.arange(correlation_frames) @@ -182,12 +185,12 @@ def find_bad_by_ransac( total_chunks = len(channel_chunks) current = 1 for chunk in channel_chunks: + interp_mats_for_chunk = [mat[chunk, :] for mat in interp_mats] channel_correlations[:, good_idx[chunk]] = _ransac_correlations( chunk, random_ch_picks, - chn_pos_good, + interp_mats_for_chunk, data[good_idx, :], - n_samples, n, w_correlation, matlab_strict, @@ -226,12 +229,52 @@ def find_bad_by_ransac( return bad_by_ransac, channel_correlations +def _make_interpolation_matrices(random_ch_picks, chn_pos_good): + """Create an interpolation matrix for each RANSAC sample of channels. + + This function takes the spatial coordinates of random subsets of currently-good + channels and uses them to predict what the signal will be at the spatial + coordinates of all other currently-good channels. The results of this process are + returned as matrices that can be multiplied with EEG data to generate predicted + signals. + + Parameters + ---------- + random_ch_picks : list of list of int + A list containing multiple random subsets of currently-good channels. 
+ chn_pos_good : np.ndarray + 3-D spatial coordinates of all currently-good channels. + + Returns + ------- + interpolation_mats : list of np.ndarray + A list of interpolation matrices, one for each random subset of channels. + Each matrix has the shape `[num_good_channels, num_good_channels]`, with the + number of good channels being inferred from the size of `ch_pos_good`. + + Notes + ----- + This function currently makes use of a private MNE function, + ``mne.channels.interpolation._make_interpolation_matrix``, to generate matrices. + + """ + n_chans_good = chn_pos_good.shape[0] + + interpolation_mats = [] + for sample in random_ch_picks: + mat = np.zeros((n_chans_good, n_chans_good)) + subset_pos = chn_pos_good[sample, :] + mat[:, sample] = _make_interpolation_matrix(subset_pos, chn_pos_good) + interpolation_mats.append(mat) + + return interpolation_mats + + def _ransac_correlations( chans_to_predict, random_ch_picks, - chn_pos_good, + interpolation_mats, data, - n_samples, n, w_correlation, matlab_strict, @@ -241,16 +284,15 @@ def _ransac_correlations( Parameters ---------- chans_to_predict : list of int - Indexes of the channels to predict as they appear in chn_pos. + Indices of the channels to predict (as they appear in `data`). random_ch_picks : list - each element is a list of indexes of the channels (as they appear - in chn_pos_good) to use for reconstruction in each of the samples. - chn_pos_good : np.ndarray - 3-D coordinates of all the channels not detected noisy so far + Each element is a list of indexes of the channels (as they appear + in `data`) to use for reconstruction in each of the samples. + interpolation_mats : list of np.ndarray + A set of channel interpolation matrices, one for each RANSAC sample of + channels. data : np.ndarray 2-D EEG data - n_samples : int - Number of samples used for computation of RANSAC. n : int Number of frames/samples of each window. 
w_correlation: int @@ -270,10 +312,9 @@ def _ransac_correlations( # Make the ransac predictions ransac_eeg = _run_ransac( - n_samples=n_samples, + chans_to_predict=chans_to_predict, random_ch_picks=random_ch_picks, - chn_pos=chn_pos_good[chans_to_predict, :], - chn_pos_good=chn_pos_good, + interpolation_mats=interpolation_mats, data=data, matlab_strict=matlab_strict, ) @@ -301,10 +342,9 @@ def _ransac_correlations( def _run_ransac( - n_samples, + chans_to_predict, random_ch_picks, - chn_pos, - chn_pos_good, + interpolation_mats, data, matlab_strict, ): @@ -315,15 +355,14 @@ def _run_ransac( Parameters ---------- - n_samples : int - number of interpolations from which a median will be computed + chans_to_predict : list of int + Indices of the channels to predict (as they appear in `data`). random_ch_picks : list - each element is a list of indexes of the channels (as they appear - in chn_pos_good) to use for reconstruction in each of the samples. - chn_pos : np.ndarray - 3-D coordinates of the electrode position - chn_pos_good : np.ndarray - 3-D coordinates of all the channels not detected noisy so far + Each element is a list of indexes of the channels (as they appear + in `data`) to use for reconstruction in each of the samples. + interpolation_mats : list of np.ndarray + A set of channel interpolation matrices, one for each RANSAC sample of + channels. data : np.ndarray 2-D EEG data matlab_strict : bool @@ -338,59 +377,30 @@ def _run_ransac( """ # n_chns, n_timepts = data.shape # 2 next lines should be equivalent but support single channel processing + ransac_samples = len(interpolation_mats) n_timepts = data.shape[1] - n_chns = chn_pos.shape[0] + n_chns = len(chans_to_predict) # Before running, make sure we have enough memory - verify_free_ram(data, n_samples, n_chns) + verify_free_ram(data, ransac_samples, n_chns) # Memory seems to be fine ... 
# Make the predictions - eeg_predictions = np.zeros((n_chns, n_timepts, n_samples)) - for sample in range(n_samples): - # Get the random channel selection for the current sample + eeg_predictions = np.zeros((n_chns, n_timepts, ransac_samples)) + for sample in range(ransac_samples): + # Get the random channels & interpolation matrix for the current sample reconstr_idx = random_ch_picks[sample] - eeg_predictions[..., sample] = _get_ransac_pred( - chn_pos, chn_pos_good, reconstr_idx, data - ) + interp_mat = interpolation_mats[sample][:, reconstr_idx] + # Predict the EEG signals for the current RANSAC sample / channel chunk + eeg_predictions[..., sample] = np.matmul(interp_mat, data[reconstr_idx, :]) # Form median from all predictions if matlab_strict: # Match MATLAB's rounding logic (.5 always rounded up) - median_idx = int(_mat_round(n_samples / 2.0) - 1) + median_idx = int(_mat_round(ransac_samples / 2.0) - 1) eeg_predictions.sort(axis=-1) ransac_eeg = eeg_predictions[:, :, median_idx] else: ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True) return ransac_eeg - - -def _get_ransac_pred(chn_pos, chn_pos_good, reconstr_idx, data): - """Perform RANSAC prediction. 
- - Parameters - ---------- - chn_pos : np.ndarray - 3-D coordinates of the electrode position - chn_pos_good : np.ndarray - 3-D coordinates of all the channels not detected noisy so far - reconstr_idx : array_like - indexes of the channels in chn_pos_good to use for reconstruction - data : np.ndarray - 2-D EEG data - - Returns - ------- - ransac_pred : np.ndarray - Single RANSAC prediction - - """ - # Get positions - reconstr_pos = chn_pos_good[reconstr_idx, :] - - # Interpolate - interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos) - ransac_pred = np.matmul(interpol_mat, data[reconstr_idx, :]) - - return ransac_pred From 29cb6bff4f91b04f35aba37099f3b32503e13b4d Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Thu, 29 Apr 2021 15:57:19 -0300 Subject: [PATCH 02/10] Internally implement window-wise RANSAC --- pyprep/find_noisy_channels.py | 1 + pyprep/ransac.py | 148 +++++++++++++++++++++++++++++----- 2 files changed, 127 insertions(+), 22 deletions(-) diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 03d02ef0..d788fd27 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -479,6 +479,7 @@ def find_bad_by_ransac( fraction_bad, corr_window_secs, channel_wise, + False, self.random_state, self.matlab_strict, ) diff --git a/pyprep/ransac.py b/pyprep/ransac.py index b0cc3b9e..961414ad 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -21,6 +21,7 @@ def find_bad_by_ransac( fraction_bad=0.4, corr_window_secs=5.0, channel_wise=False, + window_wise=False, random_state=None, matlab_strict=False, ): @@ -131,7 +132,11 @@ def find_bad_by_ransac( # Before running, make sure we have enough memory when using the # smallest possible chunk size - verify_free_ram(data, n_samples, 1) + if window_wise: + window_size = int(sample_rate * corr_window_secs) + verify_free_ram(data[:, :window_size], n_samples, n_chans_good) + else: + verify_free_ram(data, n_samples, 1) # Generate random channel picks for each RANSAC 
sample random_ch_picks = [] @@ -145,25 +150,31 @@ def find_bad_by_ransac( # Generate interpolation matrix for each RANSAC sample interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good) - # Correlation windows setup + # Calculate the size (in frames) and count of correlation windows correlation_frames = corr_window_secs * sample_rate - correlation_window = np.arange(correlation_frames) - n = correlation_window.shape[0] signal_frames = data.shape[1] correlation_offsets = np.arange( 0, (signal_frames - correlation_frames), correlation_frames ) - w_correlation = correlation_offsets.shape[0] + win_size = int(correlation_frames) + win_count = correlation_offsets.shape[0] - # Preallocate + # Preallocate RANSAC correlation matrix n_chans_complete = len(complete_chn_labs) - channel_correlations = np.ones((w_correlation, n_chans_complete)) + channel_correlations = np.ones((win_count, n_chans_complete)) # Notice self.EEGData.shape[0] = self.n_chans_new # Is now data.shape[0] = n_chans_complete # They came from the same drop of channels print("Executing RANSAC\nThis may take a while, so be patient...") + # If enabled, run window-wise RANSAC + if window_wise: + # Get correlations between actual vs predicted signals for each RANSAC window + channel_correlations[:, good_idx] = _ransac_by_window( + data[good_idx, :], interp_mats, win_size, win_count, matlab_strict + ) + # Calculate smallest chunk size for each possible chunk count chunk_sizes = [] chunk_count = 0 @@ -173,26 +184,25 @@ def find_bad_by_ransac( chunk_count = n_chunks chunk_sizes.append(i) - chunk_size = chunk_sizes.pop() + chunk_size = 1 if channel_wise else chunk_sizes.pop() mem_error = True job = list(range(n_chans_good)) - if channel_wise: - chunk_size = 1 - while mem_error: + # If not using window-wise RANSAC, do channel-wise RANSAC + while mem_error and not window_wise: try: channel_chunks = split_list(job, chunk_size) total_chunks = len(channel_chunks) current = 1 for chunk in channel_chunks: 
interp_mats_for_chunk = [mat[chunk, :] for mat in interp_mats] - channel_correlations[:, good_idx[chunk]] = _ransac_correlations( + channel_correlations[:, good_idx[chunk]] = _ransac_by_channel( chunk, random_ch_picks, interp_mats_for_chunk, data[good_idx, :], - n, - w_correlation, + win_size, + win_count, matlab_strict, ) if chunk == channel_chunks[0]: @@ -270,7 +280,98 @@ def _make_interpolation_matrices(random_ch_picks, chn_pos_good): return interpolation_mats -def _ransac_correlations( +def _ransac_by_window(data, interpolation_mats, win_size, win_count, matlab_strict): + """Calculate correlations of channels with their RANSAC-predicted values. + + This function calculates RANSAC correlations for each RANSAC window + individually, requiring RAM equivalent to [channels * sample rate * seconds + per RANSAC window] to run. Generally, this method will use less RAM than + :func:`_ransac_by_channel`, with the exception of short recordings with high + electrode counts. + + Parameters + ---------- + data : np.ndarray + A 2-D array containing the EEG signals from currently-good channels. + interpolation_mats : list of np.ndarray + A list of interpolation matrices, one for each RANSAC sample of channels. + win_size : int + Number of frames/samples of EEG data in each RANSAC correlation window. + win_count: int + Number of RANSAC correlation windows. + + Returns + ------- + channel_correlations : np.ndarray + Correlations of the given channels to their predicted values within each + RANSAC window. 
+ """ + ch_count = data.shape[0] + ch_correlations = np.ones((win_count, ch_count)) + + for window in range(win_count): + + # Get the current window of EEG data + start = window * win_size + end = (window + 1) * win_size + actual = data[:, start:end] + + # Get the median RANSAC-predicted signal for each channel + predicted = _predict_median_signals(actual, interpolation_mats, matlab_strict) + + # Calculate the actual vs predicted signal correlation for each channel + ch_correlations[window, :] = _correlate_arrays(actual, predicted, matlab_strict) + + return ch_correlations + + +def _predict_median_signals(window, interpolation_mats, matlab_strict=False): + """Calculate the median RANSAC-predicted signal for a given window of data. + + Parameters + ---------- + window : np.ndarray + A 2-D window of EEG data with the shape `[channels, samples]`. + interpolation_mats : list of np.ndarray + A set of channel interpolation matrices, one for each RANSAC sample of + channels. + matlab_strict : bool, optional + Whether MATLAB PREP's internal logic should be strictly followed (see Notes). + Defaults to False. + + Returns + ------- + predicted : np.ndarray + The median RANSAC-predicted EEG signal for the given window of data. + + Notes + ----- + In MATLAB PREP, the median signal is calculated by sorting the different + predictions for each EEG sample/channel from low to high and then taking the value + at the middle index (as calculated by `int(n_ransac_samples / 2.0)`) for each. + Because this logic only returns the correct result for odd numbers of samples, the + current function will instead return the true median signal across predictions + unless strict MATLAB equivalence is requested. 
+ + """ + ransac_samples = len(interpolation_mats) + merged_mats = np.concatenate(interpolation_mats, axis=0) + + predictions_per_sample = np.reshape( + np.matmul(merged_mats, window), + (ransac_samples, window.shape[0], window.shape[1]) + ) + + if matlab_strict: + # Match MATLAB's rounding logic (.5 always rounded up) + median_idx = int(_mat_round(ransac_samples / 2.0) - 1) + predictions_per_sample.sort(axis=0) + return predictions_per_sample[median_idx, :, :] + else: + return np.median(predictions_per_sample, axis=0) + + +def _ransac_by_channel( chans_to_predict, random_ch_picks, interpolation_mats, @@ -279,7 +380,13 @@ def _ransac_correlations( w_correlation, matlab_strict, ): - """Get correlations of channels to their RANSAC-predicted values. + """Calculate correlations of channels with their RANSAC-predicted values. + + This function calculates RANSAC correlations on one (or more) full channels + at once, requiring RAM equivalent to [channels per chunk * sample rate * + length of recording in seconds] to run. Generally, this method will use + more RAM than :func:`_ransac_by_window`, but may be faster for systems with + large amounts of RAM. Parameters ---------- @@ -311,7 +418,7 @@ def _ransac_correlations( channel_correlations = np.ones((w_correlation, len(chans_to_predict))) # Make the ransac predictions - ransac_eeg = _run_ransac( + ransac_eeg = _predict_median_signals_channelwise( chans_to_predict=chans_to_predict, random_ch_picks=random_ch_picks, interpolation_mats=interpolation_mats, @@ -341,17 +448,14 @@ def _ransac_correlations( return channel_correlations -def _run_ransac( +def _predict_median_signals_channelwise( chans_to_predict, random_ch_picks, interpolation_mats, data, matlab_strict, ): - """Detect noisy channels apart from the ones described previously. - - It creates a random subset of the so-far good channels - and predicts the values of the channels not in the subset. 
+ """Calculate the median RANSAC-predicted signal for a given chunk of channels. Parameters ---------- From 83b38d65b2161a1f6629da0b17672751e776ddbe Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Thu, 29 Apr 2021 16:28:04 -0300 Subject: [PATCH 03/10] Make channel-wise API similar to window-wise --- pyprep/ransac.py | 124 ++++++++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 61 deletions(-) diff --git a/pyprep/ransac.py b/pyprep/ransac.py index 961414ad..9346f21e 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -197,12 +197,12 @@ def find_bad_by_ransac( for chunk in channel_chunks: interp_mats_for_chunk = [mat[chunk, :] for mat in interp_mats] channel_correlations[:, good_idx[chunk]] = _ransac_by_channel( - chunk, - random_ch_picks, - interp_mats_for_chunk, data[good_idx, :], + interp_mats_for_chunk, win_size, win_count, + chunk, + random_ch_picks, matlab_strict, ) if chunk == channel_chunks[0]: @@ -292,22 +292,26 @@ def _ransac_by_window(data, interpolation_mats, win_size, win_count, matlab_stri Parameters ---------- data : np.ndarray - A 2-D array containing the EEG signals from currently-good channels. + A 2-D array containing the EEG signals from all currently-good channels. interpolation_mats : list of np.ndarray A list of interpolation matrices, one for each RANSAC sample of channels. win_size : int Number of frames/samples of EEG data in each RANSAC correlation window. win_count: int Number of RANSAC correlation windows. + matlab_strict : bool + Whether or not RANSAC should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code. Returns ------- - channel_correlations : np.ndarray + correlations : np.ndarray Correlations of the given channels to their predicted values within each RANSAC window. 
+ """ ch_count = data.shape[0] - ch_correlations = np.ones((win_count, ch_count)) + correlations = np.ones((win_count, ch_count)) for window in range(win_count): @@ -320,9 +324,9 @@ def _ransac_by_window(data, interpolation_mats, win_size, win_count, matlab_stri predicted = _predict_median_signals(actual, interpolation_mats, matlab_strict) # Calculate the actual vs predicted signal correlation for each channel - ch_correlations[window, :] = _correlate_arrays(actual, predicted, matlab_strict) + correlations[window, :] = _correlate_arrays(actual, predicted, matlab_strict) - return ch_correlations + return correlations def _predict_median_signals(window, interpolation_mats, matlab_strict=False): @@ -335,9 +339,9 @@ def _predict_median_signals(window, interpolation_mats, matlab_strict=False): interpolation_mats : list of np.ndarray A set of channel interpolation matrices, one for each RANSAC sample of channels. - matlab_strict : bool, optional - Whether MATLAB PREP's internal logic should be strictly followed (see Notes). - Defaults to False. + matlab_strict : bool + Whether or not RANSAC should strictly follow MATLAB PREP's internal + math, ignoring any improvements made in PyPREP over the original code. Returns ------- @@ -372,12 +376,12 @@ def _predict_median_signals(window, interpolation_mats, matlab_strict=False): def _ransac_by_channel( + data, + interpolation_mats, + win_size, + win_count, chans_to_predict, random_ch_picks, - interpolation_mats, - data, - n, - w_correlation, matlab_strict, ): """Calculate correlations of channels with their RANSAC-predicted values. @@ -390,107 +394,107 @@ def _ransac_by_channel( Parameters ---------- - chans_to_predict : list of int - Indices of the channels to predict (as they appear in `data`). - random_ch_picks : list - Each element is a list of indexes of the channels (as they appear - in `data`) to use for reconstruction in each of the samples. 
+ data : np.ndarray + A 2-D array containing the EEG signals from all currently-good channels. interpolation_mats : list of np.ndarray A set of channel interpolation matrices, one for each RANSAC sample of channels. - data : np.ndarray - 2-D EEG data - n : int - Number of frames/samples of each window. - w_correlation: int - Number of windows. + win_size : int + Number of frames/samples of EEG data in each RANSAC correlation window. + win_count: int + Number of RANSAC correlation windows. + chans_to_predict : list of int + Indices of the channels to predict (as they appear in `data`) within the + current chunk. + random_ch_picks : list of list of int + A list containing multiple random subsets of currently-good channels. matlab_strict : bool Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. Returns ------- - channel_correlations : np.ndarray - correlations of the given channels to their RANSAC-predicted values. + correlations : np.ndarray + Correlations of the given channels to their predicted values within each + RANSAC window. 
""" - # Preallocate - channel_correlations = np.ones((w_correlation, len(chans_to_predict))) + # Preallocate RANSAC correlation matrix for current chunk + chunk_size = len(chans_to_predict) + correlations = np.ones((win_count, chunk_size)) - # Make the ransac predictions - ransac_eeg = _predict_median_signals_channelwise( - chans_to_predict=chans_to_predict, - random_ch_picks=random_ch_picks, - interpolation_mats=interpolation_mats, + # Get median RANSAC predictions for each channel in the current chunk + predicted_chans = _predict_median_signals_channelwise( data=data, + interpolation_mats=interpolation_mats, + random_ch_picks=random_ch_picks, + chunk_size=len(chans_to_predict), matlab_strict=matlab_strict, ) # Correlate ransac prediction and eeg data # For the actual data - data_window = data[chans_to_predict, : n * w_correlation] - data_window = data_window.reshape(len(chans_to_predict), w_correlation, n) + data_window = data[chans_to_predict, : win_size * win_count] + data_window = data_window.reshape(chunk_size, win_count, win_size) data_window = data_window.swapaxes(1, 0) # For the ransac predicted eeg - pred_window = ransac_eeg[: len(chans_to_predict), : n * w_correlation] - pred_window = pred_window.reshape(len(chans_to_predict), w_correlation, n) + pred_window = predicted_chans[: chunk_size, : win_size * win_count] + pred_window = pred_window.reshape(chunk_size, win_count, win_size) pred_window = pred_window.swapaxes(1, 0) # Perform correlations - for k in range(w_correlation): + for k in range(win_count): data_portion = data_window[k, :, :] pred_portion = pred_window[k, :, :] R = _correlate_arrays(data_portion, pred_portion, matlab_strict) - channel_correlations[k, :] = R + correlations[k, :] = R - return channel_correlations + return correlations def _predict_median_signals_channelwise( - chans_to_predict, - random_ch_picks, - interpolation_mats, data, + interpolation_mats, + random_ch_picks, + chunk_size, matlab_strict, ): """Calculate the median 
RANSAC-predicted signal for a given chunk of channels. Parameters ---------- - chans_to_predict : list of int - Indices of the channels to predict (as they appear in `data`). - random_ch_picks : list - Each element is a list of indexes of the channels (as they appear - in `data`) to use for reconstruction in each of the samples. + data : np.ndarray + A 2-D array containing the EEG signals from all currently-good channels. interpolation_mats : list of np.ndarray A set of channel interpolation matrices, one for each RANSAC sample of channels. - data : np.ndarray - 2-D EEG data + random_ch_picks : list of list of int + A list containing multiple random subsets of currently-good channels. + chunk_size : int + The number of channels to predict in the current chunk. matlab_strict : bool Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. Returns ------- - ransac_eeg : np.ndarray - The EEG data predicted by RANSAC + predicted_chans : np.ndarray + The median RANSAC-predicted EEG signals for the given chunk of channels. """ # n_chns, n_timepts = data.shape # 2 next lines should be equivalent but support single channel processing ransac_samples = len(interpolation_mats) n_timepts = data.shape[1] - n_chns = len(chans_to_predict) # Before running, make sure we have enough memory - verify_free_ram(data, ransac_samples, n_chns) + verify_free_ram(data, ransac_samples, chunk_size) # Memory seems to be fine ... 
# Make the predictions - eeg_predictions = np.zeros((n_chns, n_timepts, ransac_samples)) + eeg_predictions = np.zeros((chunk_size, n_timepts, ransac_samples)) for sample in range(ransac_samples): # Get the random channels & interpolation matrix for the current sample reconstr_idx = random_ch_picks[sample] @@ -503,8 +507,6 @@ def _predict_median_signals_channelwise( # Match MATLAB's rounding logic (.5 always rounded up) median_idx = int(_mat_round(ransac_samples / 2.0) - 1) eeg_predictions.sort(axis=-1) - ransac_eeg = eeg_predictions[:, :, median_idx] + return eeg_predictions[:, :, median_idx] else: - ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True) - - return ransac_eeg + return np.median(eeg_predictions, axis=-1, overwrite_input=True) From 354f6354ab45c28590394bbe82e812f987cafd41 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Thu, 29 Apr 2021 18:45:01 -0300 Subject: [PATCH 04/10] Added function for printing window-wise progress --- pyprep/ransac.py | 6 +++++- pyprep/utils.py | 38 ++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 42 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 2 deletions(-) diff --git a/pyprep/ransac.py b/pyprep/ransac.py index 9346f21e..7d0a791b 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -5,7 +5,8 @@ from mne.utils import check_random_state from pyprep.utils import ( - split_list, verify_free_ram, _get_random_subset, _mat_round, _correlate_arrays + _get_random_subset, _mat_round, _correlate_arrays, print_progress, + split_list, verify_free_ram ) @@ -315,6 +316,9 @@ def _ransac_by_window(data, interpolation_mats, win_size, win_count, matlab_stri for window in range(win_count): + # Print RANSAC progress in 10% increments + print_progress(window + 1, win_count, every=0.1) + # Get the current window of EEG data start = window * win_size end = (window + 1) * win_size diff --git a/pyprep/utils.py b/pyprep/utils.py index 51c95185..47aa8e9b 100644 --- a/pyprep/utils.py +++ 
b/pyprep/utils.py @@ -362,6 +362,44 @@ def split_list(mylist, chunk_size): ] +def print_progress(current, end, start=None, stepsize=1, every=0.1): + """Print the current progress in a loop. + + Parameters + ---------- + current: {int, float} + The index or numeric value of the current position in the loop. + end: {int, float} + The final index or numeric value in the loop. + start: {int, float, None}, optional + The first index or numeric value in the loop. If ``None``, the start + index will assumed to be `stepsize` (i.e., 3 if `stepsize` is 3). + Defaults to ``None``. + stepsize: {int, float}, optional + The fixed amount by which `current` increases every iteration of the + loop. Defaults to ``1``. + every: float, optional + The frequency with which to print progress updates during the loop, + as a proportion between 0 and 1, exclusive. Defaults to ``0.1``, which + prints a progress update after every 10%. + + """ + start = stepsize if not start else start + end = end - start + 1 + current = current - start + 1 + + if current == 1: + print("Progress:", end=" ", flush=True) + elif current == end: + print("100%") + elif current > 0: + progress = float(current) / end + last = float(current - stepsize) / end + if int(progress / every) > int(last / every): + pct = int(progress / every) * every * 100 + print("{0}%...".format(int(pct)), end=" ", flush=True) + + def make_random_mne_object( ch_names, ch_types, diff --git a/tests/test_utils.py b/tests/test_utils.py index 8cf35476..624fdfdf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ from pyprep.utils import ( _mat_round, _mat_quantile, _mat_iqr, _get_random_subset, _correlate_arrays, - _eeglab_create_highpass + _eeglab_create_highpass, print_progress ) @@ -114,3 +114,43 @@ def test_eeglab_create_highpass(): expected_val = 0.9961 actual_val = vals[len(vals) // 2] assert np.isclose(expected_val, actual_val, atol=0.001) + + +def test_print_progress(capsys): + """Test the function for printing 
progress updates within a loop.""" + # Test printing start value + print_progress(1, 20) + captured = capsys.readouterr() + assert captured.out == "Progress: " + + # Test printing end values + iterations = 27 + for i in range(iterations): + print_progress(i + 1, iterations, every=0.2) + captured = capsys.readouterr() + assert captured.out == "Progress: 20%... 40%... 60%... 80%... 100%\n" + + # Test printing of updates at right times + iterations = 176 + for i in range(iterations): + print_progress(i + 1, iterations) + if (i + 1) == 17: + captured = capsys.readouterr() + assert captured.out == "Progress: " + elif (i + 1) == 18: + captured = capsys.readouterr() + assert captured.out == "10%... " + break + + # Test shifted start value + iterations = 25 + start = 5 + for i in range(start, iterations + 1): + print_progress(i, iterations, start=start) + if i == 6: + captured = capsys.readouterr() + assert captured.out == "Progress: " + elif i == 7: + captured = capsys.readouterr() + assert captured.out == "10%... 
" + break From b13b87411a315c0cbe09e2a0a6bf194eda8379c4 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Thu, 29 Apr 2021 19:15:34 -0300 Subject: [PATCH 05/10] Change meaning of channel_wise, add max chunk size --- examples/run_ransac.py | 9 +++++---- pyprep/find_noisy_channels.py | 19 ++++++++++++++----- pyprep/ransac.py | 32 ++++++++++++++++++++------------ 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/examples/run_ransac.py b/examples/run_ransac.py index bfff7b0e..93f2705d 100644 --- a/examples/run_ransac.py +++ b/examples/run_ransac.py @@ -74,14 +74,15 @@ nd2 = NoisyChannels(raw) ############################################################################### -# Find all bad channels and print a summary +# Find all bad channels using channel-wise RANSAC and print a summary start_time = perf_counter() -nd.find_bad_by_ransac() +nd.find_bad_by_ransac(channel_wise=True) print("--- %s seconds ---" % (perf_counter() - start_time)) -# Repeat RANSAC in a channel wise manner. This is slower but needs less memory. +# Repeat channel-wise RANSAC using a single channel at a time. This is slower +# but needs less memory. start_time = perf_counter() -nd2.find_bad_by_ransac(channel_wise=True) +nd2.find_bad_by_ransac(channel_wise=True, max_chunk_size=1) print("--- %s seconds ---" % (perf_counter() - start_time)) ############################################################################### diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index d788fd27..7627b61d 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -409,6 +409,7 @@ def find_bad_by_ransac( fraction_bad=0.4, corr_window_secs=5.0, channel_wise=False, + max_chunk_size=None, ): """Detect channels that are not predicted well by other channels. @@ -447,10 +448,18 @@ def find_bad_by_ransac( The duration (in seconds) of each RANSAC correlation window. Defaults to 5 seconds. 
channel_wise : bool, optional - Whether RANSAC should be performed one channel at a time (lower RAM - demands) or in chunks of as many channels as can fit into the - currently available RAM (faster). Defaults to ``False`` (i.e., using - the faster method). + Whether RANSAC should predict signals for whole chunks of channels + at once instead of predicting signals for each RANSAC window + individually. Channel-wise RANSAC generally has higher RAM demands + than window-wise RANSAC (especially if `max_chunk_size` is + ``None``), but can be faster on systems with lots of RAM to spare. + Defaults to ``False``. + max_chunk_size : {int, None}, optional + The maximum number of channels to predict at once during + channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk + size that will fit into the available RAM, which may slow down + other programs on the host system. If using window-wise RANSAC + (the default), this parameter has no effect. Defaults to ``None``. References ---------- @@ -479,7 +488,7 @@ def find_bad_by_ransac( fraction_bad, corr_window_secs, channel_wise, - False, + max_chunk_size, self.random_state, self.matlab_strict, ) diff --git a/pyprep/ransac.py b/pyprep/ransac.py index 7d0a791b..163708d7 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -22,7 +22,7 @@ def find_bad_by_ransac( fraction_bad=0.4, corr_window_secs=5.0, channel_wise=False, - window_wise=False, + max_chunk_size=None, random_state=None, matlab_strict=False, ): @@ -72,10 +72,17 @@ def find_bad_by_ransac( The duration (in seconds) of each RANSAC correlation window. Defaults to 5 seconds. channel_wise : bool, optional - Whether RANSAC should be performed one channel at a time (lower RAM - demands) or in chunks of as many channels as can fit into the currently - available RAM (faster). Defaults to ``False`` (i.e., using the faster - method). 
+ Whether RANSAC should predict signals for whole chunks of channels at + once instead of predicting signals for each RANSAC window individually. + Channel-wise RANSAC generally has higher RAM demands than window-wise + RANSAC (especially if `max_chunk_size` is ``None``), but can be faster + on systems with lots of RAM to spare. Defaults to ``False``. + max_chunk_size : {int, None}, optional + The maximum number of channels to predict at once during channel-wise + RANSAC. If ``None``, RANSAC will use the largest chunk size that will + fit into the available RAM, which may slow down other programs on the + host system. If using window-wise RANSAC (the default), this parameter + has no effect. Defaults to ``None``. random_state : {int, None, np.random.RandomState}, optional The random seed with which to generate random samples of channels during RANSAC. If random_state is an int, it will be used as a seed for RandomState. @@ -133,11 +140,11 @@ def find_bad_by_ransac( # Before running, make sure we have enough memory when using the # smallest possible chunk size - if window_wise: + if channel_wise: + verify_free_ram(data, n_samples, 1) + else: window_size = int(sample_rate * corr_window_secs) verify_free_ram(data[:, :window_size], n_samples, n_chans_good) - else: - verify_free_ram(data, n_samples, 1) # Generate random channel picks for each RANSAC sample random_ch_picks = [] @@ -170,7 +177,7 @@ def find_bad_by_ransac( print("Executing RANSAC\nThis may take a while, so be patient...") # If enabled, run window-wise RANSAC - if window_wise: + if not channel_wise: # Get correlations between actual vs predicted signals for each RANSAC window channel_correlations[:, good_idx] = _ransac_by_window( data[good_idx, :], interp_mats, win_size, win_count, matlab_strict @@ -183,14 +190,15 @@ def find_bad_by_ransac( n_chunks = int(np.ceil(n_chans_good / i)) if n_chunks != chunk_count: chunk_count = n_chunks - chunk_sizes.append(i) + if not max_chunk_size or i <= max_chunk_size: + 
chunk_sizes.append(i) - chunk_size = 1 if channel_wise else chunk_sizes.pop() + chunk_size = chunk_sizes.pop() mem_error = True job = list(range(n_chans_good)) # If not using window-wise RANSAC, do channel-wise RANSAC - while mem_error and not window_wise: + while mem_error and channel_wise: try: channel_chunks = split_list(job, chunk_size) total_chunks = len(channel_chunks) From 629950b0b892abbbad7b49f46ba1a47fff4dc744 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Thu, 29 Apr 2021 23:30:21 -0300 Subject: [PATCH 06/10] Overhaul/extend unit tests for RANSAC --- tests/test_find_noisy_channels.py | 79 +++++++++++++++++++------------ 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/tests/test_find_noisy_channels.py b/tests/test_find_noisy_channels.py index 4e79558a..c54d4add 100644 --- a/tests/test_find_noisy_channels.py +++ b/tests/test_find_noisy_channels.py @@ -3,6 +3,7 @@ import pytest from pyprep.find_noisy_channels import NoisyChannels +from pyprep.removeTrend import removeTrend @pytest.mark.usefixtures("raw", "montage") @@ -111,48 +112,68 @@ def test_findnoisychannels(raw, montage): nd.find_bad_by_SNR() assert rand_chn_lab in nd.bad_by_SNR - # Test for finding bad channels by RANSAC - raw_tmp = raw.copy() - # Ransac identifies channels that go bad together and are highly correlated. - # Inserting highly correlated signal in channels 0 through 3 at 30 Hz - raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_bad_by_ransac() - bads = nd.bad_by_ransac - assert bads == raw_tmp.ch_names[0:6] - # Test for finding bad channels by matlab_strict RANSAC - raw_tmp = raw.copy() - # Ransac identifies channels that go bad together and are highly correlated. 
- # Inserting highly correlated signal in channels 0 through 3 at 30 Hz - raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng, matlab_strict=True) - nd.find_bad_by_ransac() - bads = nd.bad_by_ransac - assert bads == raw_tmp.ch_names[0:6] +@pytest.mark.usefixtures("raw", "montage") +def test_find_bad_by_ransac(raw, montage): + """Test the RANSAC component of NoisyChannels.""" + # Set a fixed random seed and a montage for the tests + rng = 435656 + raw.set_montage(montage) - # Test for finding bad channels by channel-wise RANSAC - raw_tmp = raw.copy() - # Ransac identifies channels that go bad together and are highly correlated. + # RANSAC identifies channels that go bad together and are highly correlated. # Inserting highly correlated signal in channels 0 through 3 at 30 Hz - raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_bad_by_ransac(channel_wise=True) - bads = nd.bad_by_ransac - assert bads == raw_tmp.ch_names[0:6] - - # Test not-enough-memory and n_samples type exceptions raw_tmp = raw.copy() raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng) + + # Pre-detrend data to save time during NoisyChannels initialization + raw_tmp._data = removeTrend(raw_tmp.get_data(), raw.info["sfreq"]) + + # Run different variations of RANSAC on the same data + test_matrix = { + # List items represent [matlab_strict, channel_wise, max_chunk_size] + 'by_window': [False, False, None], + 'by_channel': [False, True, None], + 'by_channel_maxchunk': [False, True, 2], + 'by_window_strict': [True, False, None], + 'by_channel_strict': [True, True, None] + } + bads = {} + corr = {} + for name, args in test_matrix.items(): + nd = NoisyChannels( + raw_tmp, do_detrend=False, random_state=rng, matlab_strict=args[0] + ) + nd.find_bad_by_ransac(channel_wise=args[1], max_chunk_size=args[2]) + # Save bad 
channels and RANSAC correlation matrix for later comparison + bads[name] = nd.bad_by_ransac + corr[name] = nd._extra_info['bad_by_ransac']['ransac_correlations'] + + # Test whether all methods detected bad channels properly + assert bads['by_window'] == raw_tmp.ch_names[0:6] + assert bads['by_channel'] == raw_tmp.ch_names[0:6] + assert bads['by_channel_maxchunk'] == raw_tmp.ch_names[0:6] + assert bads['by_window_strict'] == raw_tmp.ch_names[0:6] + assert bads['by_channel_strict'] == raw_tmp.ch_names[0:6] + + # Make sure non-strict correlation matrices all match + assert np.allclose(corr['by_window'], corr['by_channel']) + assert np.allclose(corr['by_window'], corr['by_channel_maxchunk']) + + # Make sure MATLAB-strict correlation matrices match + assert np.allclose(corr['by_window_strict'], corr['by_channel_strict']) + + # Make sure strict and non-strict matrices differ + assert not np.allclose(corr['by_window'], corr['by_window_strict']) # Set n_samples very very high to trigger a memory error n_samples = int(1e100) + nd = NoisyChannels(raw_tmp, do_detrend=False, random_state=rng) with pytest.raises(MemoryError): nd.find_bad_by_ransac(n_samples=n_samples) # Set n_samples to a float to trigger a type error n_samples = 35.5 + nd = NoisyChannels(raw_tmp, do_detrend=False, random_state=rng) with pytest.raises(TypeError): nd.find_bad_by_ransac(n_samples=n_samples) From 94d157fa4733bf2f945720566d5a014526528c51 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Sat, 1 May 2021 00:23:17 -0300 Subject: [PATCH 07/10] Add public API for channel-wise/max chunk settings --- pyprep/find_noisy_channels.py | 28 +++++++++++++++++++++++---- pyprep/prep_pipeline.py | 25 +++++++++++++++++++++--- pyprep/reference.py | 36 +++++++++++++++++++++++++++++------ 3 files changed, 76 insertions(+), 13 deletions(-) diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 7627b61d..9bf61ec1 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ 
-144,22 +144,42 @@ def get_bads(self, verbose=False): ) return bads - def find_all_bads(self, ransac=True): + def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None): """Call all the functions to detect bad channels. This function calls all the bad-channel detecting functions. Parameters ---------- - ransac : bool - To detect channels by ransac or not. + ransac : bool, optional + Whether RANSAC should be for bad channel detection, in addition to + the other methods. RANSAC can detect bad channels that other methods + are unable to catch, but also slows down noisy channel detection + considerably. Defaults to ``True``. + channel_wise : bool, optional + Whether RANSAC should predict signals for whole chunks of channels + at once instead of predicting signals for each RANSAC window + individually. Channel-wise RANSAC generally has higher RAM demands + than window-wise RANSAC (especially if `max_chunk_size` is + ``None``), but can be faster on systems with lots of RAM to spare. + Has no effect if not using RANSAC. Defaults to ``False``. + max_chunk_size : {int, None}, optional + The maximum number of channels to predict at once during + channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk + size that will fit into the available RAM, which may slow down + other programs on the host system. If using window-wise RANSAC + (the default) or not using RANSAC at all, this parameter has no + effect. Defaults to ``None``. 
""" self.find_bad_by_nan_flat() self.find_bad_by_deviation() self.find_bad_by_SNR() if ransac: - self.find_bad_by_ransac() + self.find_bad_by_ransac( + channel_wise=channel_wise, + max_chunk_size=max_chunk_size + ) def find_bad_by_nan_flat(self): """Detect channels that appear flat or have NaN values.""" diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 2c4a38b8..3ca58a90 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -38,6 +38,19 @@ class PrepPipeline: ransac : bool, optional Whether or not to use RANSAC for noisy channel detection in addition to the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. + channel_wise : bool, optional + Whether RANSAC should predict signals for whole chunks of channels at + once instead of predicting signals for each RANSAC window + individually. Channel-wise RANSAC generally has higher RAM demands than + window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can + be faster on systems with lots of RAM to spare.nHas no effect if not + using RANSAC. Defaults to ``False``. + max_chunk_size : {int, None}, optional + The maximum number of channels to predict at once during channel-wise + RANSAC. If ``None``, RANSAC will use the largest chunk size that will + fit into the available RAM, which may slow down other programs on the + host system. If using window-wise RANSAC (the default) or not using + RANSAC at all, this parameter has no effect. Defaults to ``None``. random_state : {int, None, np.random.RandomState}, optional The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. 
@@ -99,6 +112,8 @@ def __init__( prep_params, montage, ransac=True, + channel_wise=False, + max_chunk_size=None, random_state=None, filter_kwargs=None, matlab_strict=False, @@ -133,7 +148,11 @@ def __init__( if self.prep_params["reref_chs"] == "eeg": self.prep_params["reref_chs"] = self.ch_names_eeg self.sfreq = self.raw_eeg.info["sfreq"] - self.ransac = ransac + self.ransac_settings = { + 'ransac': ransac, + 'channel_wise': channel_wise, + 'max_chunk_size': max_chunk_size + } self.random_state = check_random_state(random_state) self.filter_kwargs = filter_kwargs self.matlab_strict = matlab_strict @@ -189,9 +208,9 @@ def fit(self): reference = Reference( self.raw_eeg, self.prep_params, - ransac=self.ransac, random_state=self.random_state, - matlab_strict=self.matlab_strict + matlab_strict=self.matlab_strict, + **self.ransac_settings ) reference.perform_reference() self.raw_eeg = reference.raw diff --git a/pyprep/reference.py b/pyprep/reference.py index 0cdddf34..0d394f1f 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -32,6 +32,19 @@ class Reference: ransac : bool, optional Whether or not to use RANSAC for noisy channel detection in addition to the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. + channel_wise : bool, optional + Whether RANSAC should predict signals for whole chunks of channels at + once instead of predicting signals for each RANSAC window + individually. Channel-wise RANSAC generally has higher RAM demands than + window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can + be faster on systems with lots of RAM to spare.nHas no effect if not + using RANSAC. Defaults to ``False``. + max_chunk_size : {int, None}, optional + The maximum number of channels to predict at once during channel-wise + RANSAC. If ``None``, RANSAC will use the largest chunk size that will + fit into the available RAM, which may slow down other programs on the + host system. 
If using window-wise RANSAC (the default) or not using + RANSAC at all, this parameter has no effect. Defaults to ``None``. random_state : {int, None, np.random.RandomState}, optional The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. @@ -51,7 +64,14 @@ class Reference: """ def __init__( - self, raw, params, ransac=True, random_state=None, matlab_strict=False + self, + raw, + params, + ransac=True, + channel_wise=False, + max_chunk_size=None, + random_state=None, + matlab_strict=False ): """Initialize the class.""" self.raw = raw.copy() @@ -62,7 +82,11 @@ def __init__( self.reference_channels = params["ref_chs"] self.rereferenced_channels = params["reref_chs"] self.sfreq = self.raw.info["sfreq"] - self.ransac = ransac + self.ransac_settings = { + 'ransac': ransac, + 'channel_wise': channel_wise, + 'max_chunk_size': max_chunk_size + } self.random_state = check_random_state(random_state) self._extra_info = {} self.matlab_strict = matlab_strict @@ -103,7 +127,7 @@ def perform_reference(self): noisy_detector = NoisyChannels( self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict ) - noisy_detector.find_all_bads(ransac=self.ransac) + noisy_detector.find_all_bads(**self.ransac_settings) # Record Noisy channels and EEG before interpolation self.bad_before_interpolation = noisy_detector.get_bads(verbose=True) @@ -141,7 +165,7 @@ def perform_reference(self): noisy_detector = NoisyChannels( self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict ) - noisy_detector.find_all_bads(ransac=self.ransac) + noisy_detector.find_all_bads(**self.ransac_settings) self.still_noisy_channels = noisy_detector.get_bads() self.raw.info["bads"] = self.still_noisy_channels self.noisy_channels_after_interpolation = { @@ -186,7 +210,7 @@ def robust_reference(self): random_state=self.random_state, matlab_strict=self.matlab_strict ) - noisy_detector.find_all_bads(ransac=self.ransac) + 
noisy_detector.find_all_bads(**self.ransac_settings) self.noisy_channels_original = { "bad_by_nan": noisy_detector.bad_by_nan, "bad_by_flat": noisy_detector.bad_by_flat, @@ -243,7 +267,7 @@ def robust_reference(self): matlab_strict=self.matlab_strict ) # Detrend applied at the beginning of the function. - noisy_detector.find_all_bads(ransac=self.ransac) + noisy_detector.find_all_bads(**self.ransac_settings) self.noisy_channels["bad_by_nan"] = _union( self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan ) From 56e7677618d414302c94b6c5f946982d5c7c4b70 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Sat, 1 May 2021 00:23:56 -0300 Subject: [PATCH 08/10] Updated whats_new, minor style fix --- docs/whats_new.rst | 3 +++ pyprep/ransac.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/whats_new.rst b/docs/whats_new.rst index f28ca968..0b7cbce4 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -41,6 +41,8 @@ Changelog - Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`) - Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`) - Added a ``matlab_strict`` method for high-pass trend removal, exactly matching MATLAB PREP's values if ``matlab_strict`` is enabled, by `Austin Hurst`_ (:gh:`71`) +- Added a window-wise implementation of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`) +- Added `max_chunk_size` parameter for specifying the maximum chunk size to use for channel-wise 
RANSAC, allowing more control over PyPREP RAM usage, by `Austin Hurst`_ (:gh:`66`) Bug ~~~ @@ -55,6 +57,7 @@ API - The permissible parameters for the following methods were removed and/or reordered: `ransac._ransac_correlations`, `ransac._run_ransac`, and `ransac._get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`) - The following methods have been moved to a new module named :mod:`~pyprep.ransac` and are now private: `NoisyChannels.ransac_correlations`, `NoisyChannels.run_ransac`, and `NoisyChannels.get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`) - The permissible parameters for the following methods were removed and/or reordered: `NoisyChannels.ransac_correlations`, `NoisyChannels.run_ransac`, and `NoisyChannels.get_ransac_pred` methods, by `Austin Hurst`_ and `Yorguin Mantilla`_ (:gh:`43`) +- Changed the meaning of the argument `channel_wise` in :meth:`~pyprep.NoisyChannels.find_bad_by_ransac` to mean 'perform RANSAC across chunks of channels instead of window-wise', from its original meaning of 'perform channel-wise RANSAC one channel at a time', by `Austin Hurst`_ (:gh:`66`) .. _changes_0_3_1: diff --git a/pyprep/ransac.py b/pyprep/ransac.py index 163708d7..3b799154 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -364,7 +364,7 @@ def _predict_median_signals(window, interpolation_mats, matlab_strict=False): ----- In MATLAB PREP, the median signal is calculated by sorting the different predictions for each EEG sample/channel from low to high and then taking the value - at the middle index (as calculated by `int(n_ransac_samples / 2.0)`) for each. + at the middle index (as calculated by ``int(n_ransac_samples / 2.0)``) for each. Because this logic only returns the correct result for odd numbers of samples, the current function will instead return the true median signal across predictions unless strict MATLAB equivalence is requested. 
From 7cbe3bff249f761139d663ab05bd35a1f5b600bd Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Sat, 1 May 2021 09:22:13 -0300 Subject: [PATCH 09/10] Update docstrings based on review --- pyprep/find_noisy_channels.py | 37 ++++++++++++++++++++--------------- pyprep/prep_pipeline.py | 15 ++++++++------ pyprep/ransac.py | 13 +++++++----- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 9bf61ec1..43b8e936 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -152,17 +152,20 @@ def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None): Parameters ---------- ransac : bool, optional - Whether RANSAC should be for bad channel detection, in addition to - the other methods. RANSAC can detect bad channels that other methods - are unable to catch, but also slows down noisy channel detection - considerably. Defaults to ``True``. + Whether RANSAC should be used for bad channel detection, in addition + to the other methods. RANSAC can detect bad channels that other + methods are unable to catch, but also slows down noisy channel + detection considerably. Defaults to ``True``. channel_wise : bool, optional - Whether RANSAC should predict signals for whole chunks of channels - at once instead of predicting signals for each RANSAC window - individually. Channel-wise RANSAC generally has higher RAM demands - than window-wise RANSAC (especially if `max_chunk_size` is - ``None``), but can be faster on systems with lots of RAM to spare. - Has no effect if not using RANSAC. Defaults to ``False``. + Whether RANSAC should predict signals for chunks of channels over the + entire signal length ("channel-wise RANSAC", see `max_chunk_size` + parameter). If ``False``, RANSAC will instead predict signals for all + channels at once but over a number of smaller time windows instead of + over the entirem signal length ("window-wise RANSAC"). 
Channel-wise + RANSAC generally has higher RAM demands than window-wise RANSAC + (especially if `max_chunk_size` is ``None``), but can be faster on + systems with lots of RAM to spare. Has no effect if not using RANSAC. + Defaults to ``False``. max_chunk_size : {int, None}, optional The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk @@ -468,12 +471,14 @@ def find_bad_by_ransac( The duration (in seconds) of each RANSAC correlation window. Defaults to 5 seconds. channel_wise : bool, optional - Whether RANSAC should predict signals for whole chunks of channels - at once instead of predicting signals for each RANSAC window - individually. Channel-wise RANSAC generally has higher RAM demands - than window-wise RANSAC (especially if `max_chunk_size` is - ``None``), but can be faster on systems with lots of RAM to spare. - Defaults to ``False``. + Whether RANSAC should predict signals for chunks of channels over the + entire signal length ("channel-wise RANSAC", see `max_chunk_size` + parameter). If ``False``, RANSAC will instead predict signals for all + channels at once but over a number of smaller time windows instead of + over the entirem signal length ("window-wise RANSAC"). Channel-wise + RANSAC generally has higher RAM demands than window-wise RANSAC + (especially if `max_chunk_size` is ``None``), but can be faster on + systems with lots of RAM to spare. Defaults to ``False``. max_chunk_size : {int, None}, optional The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 3ca58a90..205c5c83 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -39,12 +39,15 @@ class PrepPipeline: Whether or not to use RANSAC for noisy channel detection in addition to the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. 
channel_wise : bool, optional - Whether RANSAC should predict signals for whole chunks of channels at - once instead of predicting signals for each RANSAC window - individually. Channel-wise RANSAC generally has higher RAM demands than - window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can - be faster on systems with lots of RAM to spare.nHas no effect if not - using RANSAC. Defaults to ``False``. + Whether RANSAC should predict signals for chunks of channels over the + entire signal length ("channel-wise RANSAC", see `max_chunk_size` + parameter). If ``False``, RANSAC will instead predict signals for all + channels at once but over a number of smaller time windows instead of + over the entirem signal length ("window-wise RANSAC"). Channel-wise + RANSAC generally has higher RAM demands than window-wise RANSAC + (especially if `max_chunk_size` is ``None``), but can be faster on + systems with lots of RAM to spare. Has no effect if not using RANSAC. + Defaults to ``False``. max_chunk_size : {int, None}, optional The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will diff --git a/pyprep/ransac.py b/pyprep/ransac.py index 3b799154..fcb16f68 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -72,11 +72,14 @@ def find_bad_by_ransac( The duration (in seconds) of each RANSAC correlation window. Defaults to 5 seconds. channel_wise : bool, optional - Whether RANSAC should predict signals for whole chunks of channels at - once instead of predicting signals for each RANSAC window individually. - Channel-wise RANSAC generally has higher RAM demands than window-wise - RANSAC (especially if `max_chunk_size` is ``None``), but can be faster - on systems with lots of RAM to spare. Defaults to ``False``. + Whether RANSAC should predict signals for chunks of channels over the + entire signal length ("channel-wise RANSAC", see `max_chunk_size` + parameter). 
If ``False``, RANSAC will instead predict signals for all + channels at once but over a number of smaller time windows instead of + over the entirem signal length ("window-wise RANSAC"). Channel-wise + RANSAC generally has higher RAM demands than window-wise RANSAC + (especially if `max_chunk_size` is ``None``), but can be faster on + systems with lots of RAM to spare. Defaults to ``False``. max_chunk_size : {int, None}, optional The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will From 96b2965b18dd689d1ed0011c8e77e1e3bc0feeea Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Sat, 1 May 2021 17:26:47 +0200 Subject: [PATCH 10/10] fix typos, update one more param docstr --- pyprep/find_noisy_channels.py | 4 ++-- pyprep/prep_pipeline.py | 2 +- pyprep/ransac.py | 2 +- pyprep/reference.py | 15 +++++++++------ 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 43b8e936..6fbf9d99 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -161,7 +161,7 @@ def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None): entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all channels at once but over a number of smaller time windows instead of - over the entirem signal length ("window-wise RANSAC"). Channel-wise + over the entire signal length ("window-wise RANSAC"). Channel-wise RANSAC generally has higher RAM demands than window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Has no effect if not using RANSAC. @@ -475,7 +475,7 @@ def find_bad_by_ransac( entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). 
If ``False``, RANSAC will instead predict signals for all channels at once but over a number of smaller time windows instead of - over the entirem signal length ("window-wise RANSAC"). Channel-wise + over the entire signal length ("window-wise RANSAC"). Channel-wise RANSAC generally has higher RAM demands than window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Defaults to ``False``. diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 205c5c83..6781b09a 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -43,7 +43,7 @@ class PrepPipeline: entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all channels at once but over a number of smaller time windows instead of - over the entirem signal length ("window-wise RANSAC"). Channel-wise + over the entire signal length ("window-wise RANSAC"). Channel-wise RANSAC generally has higher RAM demands than window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Has no effect if not using RANSAC. diff --git a/pyprep/ransac.py b/pyprep/ransac.py index fcb16f68..3e1b86a0 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -76,7 +76,7 @@ def find_bad_by_ransac( entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all channels at once but over a number of smaller time windows instead of - over the entirem signal length ("window-wise RANSAC"). Channel-wise + over the entire signal length ("window-wise RANSAC"). Channel-wise RANSAC generally has higher RAM demands than window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Defaults to ``False``. 
diff --git a/pyprep/reference.py b/pyprep/reference.py index 0d394f1f..59108676 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -33,12 +33,15 @@ class Reference: Whether or not to use RANSAC for noisy channel detection in addition to the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. channel_wise : bool, optional - Whether RANSAC should predict signals for whole chunks of channels at - once instead of predicting signals for each RANSAC window - individually. Channel-wise RANSAC generally has higher RAM demands than - window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can - be faster on systems with lots of RAM to spare.nHas no effect if not - using RANSAC. Defaults to ``False``. + Whether RANSAC should predict signals for chunks of channels over the + entire signal length ("channel-wise RANSAC", see `max_chunk_size` + parameter). If ``False``, RANSAC will instead predict signals for all + channels at once but over a number of smaller time windows instead of + over the entire signal length ("window-wise RANSAC"). Channel-wise + RANSAC generally has higher RAM demands than window-wise RANSAC + (especially if `max_chunk_size` is ``None``), but can be faster on + systems with lots of RAM to spare. Has no effect if not using RANSAC. + Defaults to ``False``. max_chunk_size : {int, None}, optional The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will