Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add window-wise RANSAC to reduce RAM requirements #66

Merged
merged 10 commits into from
May 1, 2021
3 changes: 3 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Changelog
- Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`)
- Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`)
- Added a ``matlab_strict`` method for high-pass trend removal, exactly matching MATLAB PREP's values if ``matlab_strict`` is enabled, by `Austin Hurst`_ (:gh:`71`)
- Added a window-wise implementaion of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`)
- Added `max_chunk_size` parameter for specifying the maximum chunk size to use for channel-wise RANSAC, allowing more control over PyPREP RAM usage, by `Austin Hurst`_ (:gh:`66`)

Bug
~~~
Expand All @@ -55,6 +57,7 @@ API
- The permissible parameters for the following methods were removed and/or reordered: `ransac._ransac_correlations`, `ransac._run_ransac`, and `ransac._get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`)
- The following methods have been moved to a new module named :mod:`~pyprep.ransac` and are now private: `NoisyChannels.ransac_correlations`, `NoisyChannels.run_ransac`, and `NoisyChannels.get_ransac_pred` methods, by `Yorguin Mantilla`_ (:gh:`51`)
- The permissible parameters for the following methods were removed and/or reordered: `NoisyChannels.ransac_correlations`, `NoisyChannels.run_ransac`, and `NoisyChannels.get_ransac_pred` methods, by `Austin Hurst`_ and `Yorguin Mantilla`_ (:gh:`43`)
- Changed the meaning of the argument `channel_wise` in :meth:`~pyprep.NoisyChannels.find_bad_by_ransac` to mean 'perform RANSAC across chunks of channels instead of window-wise', from its original meaning of 'perform channel-wise RANSAC one channel at a time', by `Austin Hurst`_ (:gh:`66`)


.. _changes_0_3_1:
Expand Down
9 changes: 5 additions & 4 deletions examples/run_ransac.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,15 @@
nd2 = NoisyChannels(raw)

###############################################################################
# Find all bad channels and print a summary
# Find all bad channels using channel-wise RANSAC and print a summary
start_time = perf_counter()
nd.find_bad_by_ransac()
nd.find_bad_by_ransac(channel_wise=True)
print("--- %s seconds ---" % (perf_counter() - start_time))

# Repeat RANSAC in a channel wise manner. This is slower but needs less memory.
# Repeat channel-wise RANSAC using a single channel at a time. This is slower
# but needs less memory.
start_time = perf_counter()
nd2.find_bad_by_ransac(channel_wise=True)
nd2.find_bad_by_ransac(channel_wise=True, max_chunk_size=1)
print("--- %s seconds ---" % (perf_counter() - start_time))

###############################################################################
Expand Down
51 changes: 43 additions & 8 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,45 @@ def get_bads(self, verbose=False):
)
return bads

def find_all_bads(self, ransac=True):
def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
"""Call all the functions to detect bad channels.

This function calls all the bad-channel detecting functions.

Parameters
----------
ransac : bool
To detect channels by ransac or not.
ransac : bool, optional
Whether RANSAC should be used for bad channel detection, in addition
to the other methods. RANSAC can detect bad channels that other
methods are unable to catch, but also slows down noisy channel
detection considerably. Defaults to ``True``.
channel_wise : bool, optional
Whether RANSAC should predict signals for chunks of channels over the
entire signal length ("channel-wise RANSAC", see `max_chunk_size`
parameter). If ``False``, RANSAC will instead predict signals for all
channels at once but over a number of smaller time windows instead of
over the entire signal length ("window-wise RANSAC"). Channel-wise
RANSAC generally has higher RAM demands than window-wise RANSAC
(especially if `max_chunk_size` is ``None``), but can be faster on
systems with lots of RAM to spare. Has no effect if not using RANSAC.
Defaults to ``False``.
max_chunk_size : {int, None}, optional
The maximum number of channels to predict at once during
channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk
size that will fit into the available RAM, which may slow down
other programs on the host system. If using window-wise RANSAC
(the default) or not using RANSAC at all, this parameter has no
effect. Defaults to ``None``.

"""
self.find_bad_by_nan_flat()
self.find_bad_by_deviation()
self.find_bad_by_SNR()
if ransac:
self.find_bad_by_ransac()
self.find_bad_by_ransac(
channel_wise=channel_wise,
max_chunk_size=max_chunk_size
)

def find_bad_by_nan_flat(self):
"""Detect channels that appear flat or have NaN values."""
Expand Down Expand Up @@ -409,6 +432,7 @@ def find_bad_by_ransac(
fraction_bad=0.4,
corr_window_secs=5.0,
channel_wise=False,
max_chunk_size=None,
):
"""Detect channels that are not predicted well by other channels.

Expand Down Expand Up @@ -447,10 +471,20 @@ def find_bad_by_ransac(
The duration (in seconds) of each RANSAC correlation window. Defaults
to 5 seconds.
channel_wise : bool, optional
Whether RANSAC should be performed one channel at a time (lower RAM
demands) or in chunks of as many channels as can fit into the
currently available RAM (faster). Defaults to ``False`` (i.e., using
the faster method).
Whether RANSAC should predict signals for chunks of channels over the
entire signal length ("channel-wise RANSAC", see `max_chunk_size`
parameter). If ``False``, RANSAC will instead predict signals for all
channels at once but over a number of smaller time windows instead of
over the entire signal length ("window-wise RANSAC"). Channel-wise
RANSAC generally has higher RAM demands than window-wise RANSAC
(especially if `max_chunk_size` is ``None``), but can be faster on
systems with lots of RAM to spare. Defaults to ``False``.
max_chunk_size : {int, None}, optional
The maximum number of channels to predict at once during
channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk
size that will fit into the available RAM, which may slow down
other programs on the host system. If using window-wise RANSAC
(the default), this parameter has no effect. Defaults to ``None``.

References
----------
Expand Down Expand Up @@ -479,6 +513,7 @@ def find_bad_by_ransac(
fraction_bad,
corr_window_secs,
channel_wise,
max_chunk_size,
self.random_state,
self.matlab_strict,
)
Expand Down
28 changes: 25 additions & 3 deletions pyprep/prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ class PrepPipeline:
ransac : bool, optional
Whether or not to use RANSAC for noisy channel detection in addition to
the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True.
channel_wise : bool, optional
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether RANSAC should predict signals for chunks of channels over the
entire signal length ("channel-wise RANSAC", see `max_chunk_size`
parameter). If ``False``, RANSAC will instead predict signals for all
channels at once but over a number of smaller time windows instead of
over the entire signal length ("window-wise RANSAC"). Channel-wise
RANSAC generally has higher RAM demands than window-wise RANSAC
(especially if `max_chunk_size` is ``None``), but can be faster on
systems with lots of RAM to spare. Has no effect if not using RANSAC.
Defaults to ``False``.
max_chunk_size : {int, None}, optional
The maximum number of channels to predict at once during channel-wise
RANSAC. If ``None``, RANSAC will use the largest chunk size that will
fit into the available RAM, which may slow down other programs on the
host system. If using window-wise RANSAC (the default) or not using
RANSAC at all, this parameter has no effect. Defaults to ``None``.
random_state : {int, None, np.random.RandomState}, optional
The random seed at which to initialize the class. If random_state is
an int, it will be used as a seed for RandomState.
Expand Down Expand Up @@ -99,6 +115,8 @@ def __init__(
prep_params,
montage,
ransac=True,
channel_wise=False,
max_chunk_size=None,
random_state=None,
filter_kwargs=None,
matlab_strict=False,
Expand Down Expand Up @@ -133,7 +151,11 @@ def __init__(
if self.prep_params["reref_chs"] == "eeg":
self.prep_params["reref_chs"] = self.ch_names_eeg
self.sfreq = self.raw_eeg.info["sfreq"]
self.ransac = ransac
self.ransac_settings = {
'ransac': ransac,
'channel_wise': channel_wise,
'max_chunk_size': max_chunk_size
}
self.random_state = check_random_state(random_state)
self.filter_kwargs = filter_kwargs
self.matlab_strict = matlab_strict
Expand Down Expand Up @@ -189,9 +211,9 @@ def fit(self):
reference = Reference(
self.raw_eeg,
self.prep_params,
ransac=self.ransac,
random_state=self.random_state,
matlab_strict=self.matlab_strict
matlab_strict=self.matlab_strict,
**self.ransac_settings
)
reference.perform_reference()
self.raw_eeg = reference.raw
Expand Down
Loading