Skip to content

Commit

Permalink
Draft of window-wise RANSAC
Browse files Browse the repository at this point in the history
  • Loading branch information
a-hurst committed Apr 16, 2021
1 parent bd2d799 commit b9ce469
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 6 deletions.
2 changes: 2 additions & 0 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def find_bad_by_ransac(
fraction_bad=0.4,
corr_window_secs=5.0,
channel_wise=False,
window_wise=False
):
"""Detect channels that are not predicted well by other channels.
Expand Down Expand Up @@ -439,5 +440,6 @@ def find_bad_by_ransac(
fraction_bad,
corr_window_secs,
channel_wise,
window_wise,
self.random_state,
)
150 changes: 144 additions & 6 deletions pyprep/ransac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import check_random_state

from pyprep.utils import split_list, verify_free_ram, _get_random_subset
from pyprep.utils import (
split_list, verify_free_ram, _get_random_subset, _mat_round,
_correlate_arrays
)


def find_bad_by_ransac(
Expand All @@ -20,6 +23,7 @@ def find_bad_by_ransac(
fraction_bad=0.4,
corr_window_secs=5.0,
channel_wise=False,
window_wise=False,
random_state=None,
):
"""Detect channels that are not predicted well by other channels.
Expand Down Expand Up @@ -127,7 +131,7 @@ def find_bad_by_ransac(
random_ch_picks.append(picks)

# Correlation windows setup
correlation_frames = corr_window_secs * sample_rate
correlation_frames = int(corr_window_secs * sample_rate)
correlation_window = np.arange(correlation_frames)
n = correlation_window.shape[0]
correlation_offsets = np.arange(
Expand Down Expand Up @@ -159,7 +163,7 @@ def find_bad_by_ransac(

if channel_wise:
chunk_size = 1
while mem_error:
while mem_error and not window_wise:
try:
channel_chunks = split_list(job, chunk_size)
total_chunks = len(channel_chunks)
Expand Down Expand Up @@ -198,9 +202,20 @@ def find_bad_by_ransac(
"ested samples."
)

# Thresholding
thresholded_correlations = channel_correlations < corr_thresh
frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)
if window_wise:
# Get correlations between actual vs predicted signals for each RANSAC window
interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good)
channel_correlations = _ransac_correlations_alt(
data[good_idx, :], interp_mats, correlation_frames, w_correlation
)
# Calculate fractions of bad RANSAC windows for each channel
thresholded_correlations = channel_correlations < corr_thresh
frac_bad_corr_windows = np.zeros(n_chans)
frac_bad_corr_windows[good_idx] = np.mean(thresholded_correlations, axis=0)
else:
# Calculate fractions of bad RANSAC windows for each channel
thresholded_correlations = channel_correlations < corr_thresh
frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)

# find the corresponding channel names and return
bad_ransac_channels_idx = np.argwhere(frac_bad_corr_windows > fraction_bad)
Expand Down Expand Up @@ -396,3 +411,126 @@ def _get_ransac_pred(
interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos)
ransac_pred = np.matmul(interpol_mat, data[reconstr_picks, :])
return ransac_pred


def _ransac_correlations_alt(good_data, interpolation_mats, win_size, win_count):
    """Correlate each good channel with its RANSAC prediction, window by window.

    The data are walked in consecutive, non-overlapping windows of ``win_size``
    frames. For every window the median RANSAC-predicted signal is computed and
    compared against the actual signal of each channel.

    Parameters
    ----------
    good_data : np.ndarray
        A 2-D array containing the EEG signals from currently-good channels,
        with channels along the first axis.
    interpolation_mats : list of np.ndarray
        A list of interpolation matrices, one for each RANSAC sample of channels.
    win_size : int
        Number of frames/samples of EEG data in each RANSAC correlation window.
    win_count : int
        Number of RANSAC correlation windows.

    Returns
    -------
    channel_correlations : np.ndarray
        Array of shape `[win_count, n_channels]` holding the correlation of each
        channel with its predicted signal within each RANSAC window.

    """
    n_channels = good_data.shape[0]

    # One row per window, one column per channel; every row is filled below.
    channel_correlations = np.ones((win_count, n_channels))

    for row in range(win_count):
        # Slice out the frames belonging to this window
        first_frame = row * win_size
        segment = good_data[:, first_frame:first_frame + win_size]

        # Median prediction across RANSAC samples (MATLAB-strict median rule)
        median_pred = _predict_median_signals(segment, interpolation_mats, True)

        # Actual vs. predicted correlation for every channel in this window
        channel_correlations[row, :] = _correlate_arrays(segment, median_pred)

    return channel_correlations


def _make_interpolation_matrices(random_ch_picks, chn_pos_good):
    """Build one channel interpolation matrix per RANSAC sample.

    For each random subset of currently-good channels, the spatial coordinates
    of the subset are used to predict the signal at the coordinates of every
    currently-good channel. Each result is returned as a square matrix that can
    be multiplied with EEG data to generate predicted signals.

    Parameters
    ----------
    random_ch_picks : list of list of int
        A list containing multiple random subsets of currently-good channels.
    chn_pos_good : np.ndarray
        3-D spatial coordinates of all currently-good channels.

    Returns
    -------
    interpolation_mats : list of np.ndarray
        A list of interpolation matrices, one for each random subset of channels.
        Each matrix has the shape `[num_good_channels, num_good_channels]`, with
        the number of good channels being inferred from the size of
        `chn_pos_good`.

    Notes
    -----
    This function currently makes use of a private MNE function,
    ``mne.channels.interpolation._make_interpolation_matrix``, to generate
    matrices.

    """
    n_good = chn_pos_good.shape[0]

    def _expand_to_square(picks):
        # Columns of channels outside the subset stay zero, so only the picked
        # channels contribute when the matrix is multiplied with the data.
        full = np.zeros((n_good, n_good))
        picked_pos = chn_pos_good[picks, :]
        full[:, picks] = _make_interpolation_matrix(picked_pos, chn_pos_good)
        return full

    return [_expand_to_square(picks) for picks in random_ch_picks]


def _predict_median_signals(window, interpolation_mats, matlab_strict=False):
    """Calculate the median RANSAC-predicted signal for a given window of data.

    Parameters
    ----------
    window : np.ndarray
        A 2-D window of EEG data with the shape `[channels, samples]`.
    interpolation_mats : list of np.ndarray
        A set of channel interpolation matrices, one for each RANSAC sample of
        channels. Must contain at least one matrix.
    matlab_strict : bool, optional
        Whether MATLAB PREP's internal logic should be strictly followed (see
        Notes). Defaults to False.

    Returns
    -------
    predicted : np.ndarray
        The median RANSAC-predicted EEG signal for the given window of data,
        with the same `[channels, samples]` shape as `window`.

    Raises
    ------
    ValueError
        If `interpolation_mats` is empty.

    Notes
    -----
    In MATLAB PREP, the median signal is calculated by sorting the different
    predictions for each EEG sample/channel from low to high and then taking the
    value at the middle index (as calculated by `int(n_ransac_samples / 2.0)`)
    for each. Because this logic only returns the correct result for odd numbers
    of samples, the current function will instead return the true median signal
    across predictions unless strict MATLAB equivalence is requested.

    """
    ransac_samples = len(interpolation_mats)
    if ransac_samples == 0:
        # Fail with a clear message instead of an opaque concatenate error
        raise ValueError(
            "At least one interpolation matrix is required to predict signals."
        )

    # Stack all matrices vertically so one matmul yields every prediction, then
    # unfold the result into [ransac_sample, channel, frame].
    merged_mats = np.concatenate(interpolation_mats, axis=0)
    predictions_per_sample = np.reshape(
        np.matmul(merged_mats, window),
        (ransac_samples, window.shape[0], window.shape[1]),
    )

    if matlab_strict:
        # Match MATLAB's rounding logic (.5 always rounded up)
        median_idx = int(_mat_round(ransac_samples / 2.0) - 1)
        # Only one order statistic is needed, so a partial sort is enough: the
        # element at `median_idx` equals the one a full sort would place there.
        partitioned = np.partition(predictions_per_sample, median_idx, axis=0)
        return partitioned[median_idx, :, :]

    return np.median(predictions_per_sample, axis=0)
10 changes: 10 additions & 0 deletions tests/test_find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ def test_findnoisychannels(raw, montage):
bads = nd.bad_by_ransac
assert bads == raw_tmp.ch_names[0:6]

# Test for finding bad channels by window-wise RANSAC
raw_tmp = raw.copy()
# Ransac identifies channels that go bad together and are highly correlated.
# Inserting highly correlated signal in channels 0 through 5 at 30 Hz
raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
nd = NoisyChannels(raw_tmp, random_state=rng)
nd.find_bad_by_ransac(window_wise=True)
bads = nd.bad_by_ransac
assert bads == raw_tmp.ch_names[0:6]

# Test for finding bad channels by channel-wise RANSAC
raw_tmp = raw.copy()
# Ransac identifies channels that go bad together and are highly correlated.
Expand Down

0 comments on commit b9ce469

Please sign in to comment.