Skip to content

Commit

Permalink
Functional ransac (#51)
Browse files Browse the repository at this point in the history
* function version of ransac

* minor doc fix

* flake 8 corrections

* Apply suggestions from code review

Co-authored-by: Stefan Appelhoff <[email protected]>

* add ransac module to API docs

* fix directives

* suggested changes

* corrected find_bad_by_ransac documentation

Co-authored-by: Stefan Appelhoff <[email protected]>
  • Loading branch information
yjmantilla and sappelhoff authored Feb 6, 2021
1 parent bfb2ab1 commit 8476ea5
Show file tree
Hide file tree
Showing 5 changed files with 438 additions and 285 deletions.
14 changes: 14 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,17 @@ The :class:`PrepPipeline` class
:toctree: generated/

PrepPipeline

The :mod:`ransac` module
===============================

.. automodule:: pyprep.ransac
:no-members:
:no-inherited-members:

.. currentmodule:: ransac

.. autosummary::
:toctree: generated/

find_bad_by_ransac
3 changes: 3 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Current

Changelog
~~~~~~~~~
- Created a new module named :mod:`ransac` which contains :func:`find_bad_by_ransac <ransac.find_bad_by_ransac>`, a standalone function mirroring the previous ransac method from the :class:`NoisyChannels` class, by `Yorguin Mantilla`_ (:gh:`51`)
- Added two attributes :attr:`PrepPipeline.noisy_channels_before_interpolation <prep_pipeline.PrepPipeline>` and :attr:`PrepPipeline.noisy_channels_after_interpolation <prep_pipeline.PrepPipeline>` which have the detailed output of each noisy criteria, by `Yorguin Mantilla`_ (:gh:`45`)
- Added two keys to the :attr:`PrepPipeline.noisy_channels_original <prep_pipeline.PrepPipeline>` dictionary: ``bad_by_dropout`` and ``bad_by_SNR``, by `Yorguin Mantilla`_ (:gh:`45`)
- Changed RANSAC chunking logic to reduce max memory use and prefer equal chunk sizes where possible, by `Austin Hurst`_ (:gh:`44`)
Expand All @@ -43,6 +44,8 @@ Bug

API
~~~
- The permissible parameters for the following methods were removed and/or reordered: :func:`ransac.ransac_correlations`, :func:`ransac.run_ransac`, and :func:`ransac.get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`)
- The following methods have been moved to a new module named :mod:`ransac` and are now private: :meth:`NoisyChannels.ransac_correlations`, :meth:`NoisyChannels.run_ransac <find_noisy_channels.NoisyChannels.run_ransac>`, and :meth:`NoisyChannels.get_ransac_pred <find_noisy_channels.NoisyChannels.get_ransac_pred>` methods, by `Yorguin Mantilla`_ (:gh:`51`)
- The permissible parameters for the following methods were removed and/or reordered: :meth:`NoisyChannels.ransac_correlations <find_noisy_channels.NoisyChannels.ransac_correlations>`, :meth:`NoisyChannels.run_ransac`, and :meth:`NoisyChannels.get_ransac_pred <find_noisy_channels.NoisyChannels.get_ransac_pred>` methods, by `Austin Hurst`_ and `Yorguin Mantilla`_ (:gh:`43`)

.. _changes_0_3_1:
Expand Down
1 change: 1 addition & 0 deletions pyprep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""initialize pyprep."""
import pyprep.ransac as ransac # noqa: F401
from pyprep.find_noisy_channels import NoisyChannels # noqa: F401
from pyprep.prep_pipeline import PrepPipeline # noqa: F401

Expand Down
311 changes: 26 additions & 285 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""finds bad channels."""
import mne
import numpy as np
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import check_random_state
from scipy import signal
from scipy.stats import iqr
from statsmodels import robust

from pyprep.ransac import find_bad_by_ransac
from pyprep.removeTrend import removeTrend
from pyprep.utils import filter_design, split_list, verify_free_ram
from pyprep.utils import filter_design


class NoisyChannels:
Expand Down Expand Up @@ -67,7 +67,6 @@ def __init__(self, raw, do_detrend=True, random_state=None):

# random_state
self.random_state = check_random_state(random_state)
self.random_ch_picks = [] # needed for ransac

# The identified bad channels
self.bad_by_nan = []
Expand Down Expand Up @@ -379,6 +378,9 @@ def find_bad_by_ransac(
):
"""Detect channels that are not predicted well by other channels.
This method is a wrapper of the ``find_bad_by_ransac`` function
from the ``ransac`` module.
Here, a ransac approach (see [1]_, and a short discussion in [2]_) is
adopted to predict a "clean EEG" dataset. After identifying clean EEG
channels through the other methods, the clean EEG dataset is
Expand Down Expand Up @@ -406,292 +408,31 @@ def find_bad_by_ransac(
channel as `bad_by_ransac`.
corr_window_secs : float
Size of the correlation window in seconds.
channel_wise : bool
If True the ransac will be done 1 channel at a time, if false
it will be done as fast as possible (more channels at a time).
References
----------
.. [1] Fischler, M.A., Bolles, R.C. (1981). Random rample consensus: A
Paradigm for Model Fitting with Applications to Image Analysis and
Automated Cartography. Communications of the ACM, 24, 381-395
Paradigm for Model Fitting with Applications to Image Analysis and
Automated Cartography. Communications of the ACM, 24, 381-395
.. [2] Jas, M., Engemann, D.A., Bekhti, Y., Raimondo, F., Gramfort, A.
(2017). Autoreject: Automated Artifact Rejection for MEG and EEG
Data. NeuroImage, 159, 417-429
(2017). Autoreject: Automated Artifact Rejection for MEG and EEG
Data. NeuroImage, 159, 417-429
"""
# First, check that the argument types are valid
if type(n_samples) != int:
err = "Argument 'n_samples' must be an int (got {0})"
raise TypeError(err.format(type(n_samples).__name__))

# Then, identify all bad channels by other means:
bads = self.get_bads()

# Get all channel positions and the position subset of "clean channels"
good_idx = mne.pick_channels(list(self.ch_names_new), include=[], exclude=bads)
good_chn_labs = self.ch_names_new[good_idx]
n_chans_good = good_idx.shape[0]
chn_pos = self.raw_mne._get_channel_positions()
chn_pos_good = chn_pos[good_idx, :]

# Check if we have enough remaining channels
# after exclusion of bad channels
n_pred_chns = int(np.ceil(fraction_good * n_chans_good))

if n_pred_chns <= 3:
raise IOError(
"Too few channels available to reliably perform"
" ransac. Perhaps, too many channels have failed"
" quality tests."
)

# Before running, make sure we have enough memory when using the
# smallest possible chunk size
verify_free_ram(self.EEGData, n_samples, 1)

# Generate random channel picks for each RANSAC sample
self.random_ch_picks = []
good_chans = np.arange(chn_pos_good.shape[0])
rng = check_random_state(self.random_state)
for i in range(n_samples):
# Pick a random subset of clean channels to use for interpolation
picks = rng.choice(good_chans, size=n_pred_chns, replace=False)
self.random_ch_picks.append(picks)

# Correlation windows setup
correlation_frames = corr_window_secs * self.sample_rate
correlation_window = np.arange(correlation_frames)
n = correlation_window.shape[0]
correlation_offsets = np.arange(
0, (self.signal_len - correlation_frames), correlation_frames
self.bad_by_ransac, _ = find_bad_by_ransac(
self.EEGData,
self.sample_rate,
self.signal_len,
self.ch_names_new,
self.raw_mne._get_channel_positions(),
self.get_bads(),
n_samples,
fraction_good,
corr_thresh,
fraction_bad,
corr_window_secs,
channel_wise,
self.random_state,
)
w_correlation = correlation_offsets.shape[0]

# Preallocate
channel_correlations = np.ones((w_correlation, self.n_chans_new))
# Notice self.EEGData.shape[0] = self.n_chans_new
# They came from the same drop of channels

print("Executing RANSAC\nThis may take a while, so be patient...")

# Calculate smallest chunk size for each possible chunk count
chunk_sizes = []
chunk_count = 0
for i in range(1, self.n_chans_new + 1):
n_chunks = int(np.ceil(self.n_chans_new / i))
if n_chunks != chunk_count:
chunk_count = n_chunks
chunk_sizes.append(i)

chunk_size = chunk_sizes.pop()
mem_error = True
job = list(range(self.n_chans_new))

if channel_wise:
chunk_size = 1
while mem_error:
try:
channel_chunks = split_list(job, chunk_size)
total_chunks = len(channel_chunks)
current = 1
for chunk in channel_chunks:
channel_correlations[:, chunk] = self.ransac_correlations(
chunk,
chn_pos,
chn_pos_good,
good_chn_labs,
self.EEGData,
n_samples,
n,
w_correlation,
)
if chunk == channel_chunks[0]:
# If it gets here, it means it is the optimal
print("Finding optimal chunk size :", chunk_size)
print("Total # of chunks:", total_chunks)
print("Current chunk:", end=" ", flush=True)

print(current, end=" ", flush=True)
current = current + 1

mem_error = False # All chunks processed, hurray!
del current
except MemoryError:
if len(chunk_sizes):
chunk_size = chunk_sizes.pop()
else: # pragma: no cover
raise MemoryError(
"Not even doing 1 channel at a time the data fits in ram..."
"You could downsample the data or reduce the number of requ"
"ested samples."
)

# Thresholding
thresholded_correlations = channel_correlations < corr_thresh
frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)

# find the corresponding channel names and return
bad_ransac_channels_idx = np.argwhere(frac_bad_corr_windows > fraction_bad)
bad_ransac_channels_name = self.ch_names_new[
bad_ransac_channels_idx.astype(int)
]
self.bad_by_ransac = [i[0] for i in bad_ransac_channels_name]
print("\nRANSAC done!")

def run_ransac(self, chn_pos, chn_pos_good, good_chn_labs, data, n_samples):
"""Detect noisy channels apart from the ones described previously.
It creates a random subset of the so-far good channels
and predicts the values of the channels not in the subset.
Parameters
----------
chn_pos : np.ndarray
3-D coordinates of the electrode position
chn_pos_good : np.ndarray
3-D coordinates of all the channels not detected noisy so far
good_chn_labs : np.ndarray | list
channel labels for the ch_pos_good channels-
data : np.ndarray
2-D EEG data
n_samples : int
number of interpolations from which a median will be computed
Returns
-------
ransac_eeg : np.ndarray
The EEG data predicted by RANSAC
"""
# n_chns, n_timepts = data.shape
# 2 next lines should be equivalent but support single channel processing
n_timepts = data.shape[1]
n_chns = chn_pos.shape[0]

# Before running, make sure we have enough memory
verify_free_ram(data, n_samples, n_chns)

# Memory seems to be fine ...
# Make the predictions
eeg_predictions = np.zeros((n_chns, n_timepts, n_samples))
for sample in range(n_samples):
eeg_predictions[..., sample] = self.get_ransac_pred(
chn_pos, chn_pos_good, good_chn_labs, data, sample
)

# Form median from all predictions
ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True)
return ransac_eeg

def get_ransac_pred(self, chn_pos, chn_pos_good, good_chn_labs, data, sample):
"""Perform RANSAC prediction.
Parameters
----------
chn_pos : np.ndarray
3-D coordinates of the electrode position
chn_pos_good : np.ndarray
3-D coordinates of all the channels not detected noisy so far
good_chn_labs : np.ndarray | list
channel labels for the ch_pos_good channels
data : np.ndarray
2-D EEG data
sample : int
the current RANSAC sample number
Returns
-------
ransac_pred : np.ndarray
Single RANSAC prediction
"""
# Get the random channel selection for the current sample
reconstr_idx = self.random_ch_picks[sample]

# Get positions and according labels
reconstr_labels = good_chn_labs[reconstr_idx]
reconstr_pos = chn_pos_good[reconstr_idx, :]

# Map the labels to their indices within the complete data
# Do not use mne.pick_channels, because it will return a sorted list.
reconstr_picks = [
list(self.ch_names_new).index(chn_lab) for chn_lab in reconstr_labels
]

# Interpolate
interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos)
ransac_pred = np.matmul(interpol_mat, data[reconstr_picks, :])
return ransac_pred

def ransac_correlations(
self,
chans_to_predict,
chn_pos,
chn_pos_good,
good_chn_labs,
data,
n_samples,
n,
w_correlation,
):
"""Get correlations of channels to their ransac predicted values.
Parameters
----------
chans_to_predict: list of int
Indexes of the channels to predict.
chn_pos : np.ndarray
3-D coordinates of the electrode positions to predict
chn_pos_good : np.ndarray
3-D coordinates of all the channels not detected noisy so far
good_chn_labs : np.ndarray | list
channel labels for the ch_pos_good channels
data : np.ndarray
2-D EEG data
n_samples : int
Number of samples used for computation of ransac.
n : int
Number of frames/samples of each window.
w_correlation: int
Number of windows.
Returns
-------
channel_correlations : np.ndarray
correlations of the given channels to their ransac predicted values.
"""
# Preallocate
channel_correlations = np.ones((w_correlation, len(chans_to_predict)))

# Make the ransac predictions
ransac_eeg = self.run_ransac(
chn_pos=chn_pos[chans_to_predict, :],
chn_pos_good=chn_pos_good,
good_chn_labs=good_chn_labs,
data=data,
n_samples=n_samples,
)

# Correlate ransac prediction and eeg data

# For the actual data
data_window = data[chans_to_predict, : n * w_correlation]
data_window = data_window.reshape(len(chans_to_predict), n, w_correlation)

# For the ransac predicted eeg
pred_window = ransac_eeg[: len(chans_to_predict), : n * w_correlation]
pred_window = pred_window.reshape(len(chans_to_predict), n, w_correlation)

# Perform correlations
for k in range(w_correlation):
data_portion = data_window[:, :, k]
pred_portion = pred_window[:, :, k]

R = np.corrcoef(data_portion, pred_portion)

# Take only correlations of data with pred
# and use diag to extract correlation of
# data_i with pred_i
R = np.diag(R[0 : len(chans_to_predict), len(chans_to_predict) :])
channel_correlations[k, :] = R

return channel_correlations
Loading

0 comments on commit 8476ea5

Please sign in to comment.