Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Split PrepPipeline into separate methods, make final interpolation optional #99

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
16 changes: 10 additions & 6 deletions examples/run_full_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import matplotlib.pyplot as plt

from pyprep.prep_pipeline import PrepPipeline
from pyprep.removeTrend import removeTrend

###############################################################################
# Let's download some data for testing. Picking the 1st run of subject 4 here.
Expand Down Expand Up @@ -104,9 +105,12 @@
#
# You can check the detected bad channels in each step of PREP.

original_bads = prep.bad_channels["original"]
post_interp_bads = prep.bad_channels["post-interpolation"]

print("Bad channels: {}".format(prep.interpolated_channels))
print("Bad channels original: {}".format(prep.noisy_channels_original["bad_all"]))
print("Bad channels after interpolation: {}".format(prep.still_noisy_channels))
print("Bad channels original: {}".format(original_bads))
print("Bad channels after interpolation: {}".format(post_interp_bads))

# Matlab's results
# ----------------
Expand Down Expand Up @@ -169,7 +173,7 @@

EEG_new_matlab = sio.loadmat(fname_mat2)
EEG_new_matlab = EEG_new_matlab["save_data"]
EEG_new = prep.EEG_new
EEG_new = removeTrend(prep.EEG_raw, sample_rate=prep.sfreq) * 1e6
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
EEG_new_max = np.max(abs(EEG_new), axis=None)
EEG_new_diff = EEG_new - EEG_new_matlab
EEG_new_mse = ((EEG_new_diff / EEG_new_max) ** 2).mean(axis=None)
Expand Down Expand Up @@ -202,7 +206,7 @@

EEG_clean_matlab = sio.loadmat(fname_mat3)
EEG_clean_matlab = EEG_clean_matlab["save_data"]
EEG_clean = prep.EEG
EEG_clean = prep.EEG_filtered * 1e6
EEG_max = np.max(abs(EEG_clean), axis=None)
EEG_diff = EEG_clean - EEG_clean_matlab
EEG_mse = ((EEG_diff / EEG_max) ** 2).mean(axis=None)
Expand Down Expand Up @@ -233,14 +237,14 @@
axs[2, 1].set_title("Line-noise removed EEG", fontsize=14)
axs[2, 0].set_ylabel("Channel Number", fontsize=14)

EEG = prep.EEG_before_interpolation
EEG = prep.EEG_post_reference * 1e6
EEG_max = np.max(abs(EEG), axis=None)
EEG_ref_mat = sio.loadmat(fname_mat4)
EEG_ref_matlab = EEG_ref_mat["save_EEG"]
reference_matlab = EEG_ref_mat["save_reference"]
EEG_ref_diff = EEG - EEG_ref_matlab
EEG_ref_mse = ((EEG_ref_diff / EEG_max) ** 2).mean(axis=None)
reference_signal = prep.reference_before_interpolation
reference_signal = prep.robust_reference_signal * 1e6
reference_max = np.max(abs(reference_signal), axis=None)
reference_diff = reference_signal - reference_matlab
reference_mse = ((reference_diff / reference_max) ** 2).mean(axis=None)
Expand Down
228 changes: 172 additions & 56 deletions pyprep/prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import mne
from mne.utils import check_random_state

from pyprep.find_noisy_channels import NoisyChannels
from pyprep.reference import Reference
from pyprep.removeTrend import removeTrend
from pyprep.utils import _set_diff, _union # noqa: F401


class PrepPipeline:
Expand Down Expand Up @@ -166,75 +164,193 @@ def __init__(
self.filter_kwargs = filter_kwargs
self.matlab_strict = matlab_strict

# Initialize attributes to be filled in later
self.EEG_raw = self.raw_eeg.get_data()
self.EEG_filtered = None
self.EEG_post_reference = None

# NOTE: 'original' refers to before initial average reference, not first
# pass afterwards. Not necessarily comparable to later values?
self.noisy_info = {
"original": None, "post-reference": None, "post-interpolation": None
}
self.bad_channels = {
"original": None, "post-reference": None, "post-interpolation": None
}
self.interpolated_channels = None
self.robust_reference_signal = None
self._interpolated_reference_signal = None

@property
def raw(self):
"""Return a version of self.raw_eeg that includes the non-eeg channels."""
full_raw = self.raw_eeg.copy()
if self.raw_non_eeg is None:
return full_raw
else:
return full_raw.add_channels([self.raw_non_eeg], force_update_info=True)
if self.raw_non_eeg is not None:
full_raw.add_channels([self.raw_non_eeg], force_update_info=True)
return full_raw

def fit(self):
"""Run the whole PREP pipeline."""
noisy_detector = NoisyChannels(self.raw_eeg, random_state=self.random_state)
noisy_detector.find_bad_by_nan_flat()
# unusable_channels = _union(
# noisy_detector.bad_by_nan, noisy_detector.bad_by_flat
# )
# reference_channels = _set_diff(self.prep_params["ref_chs"], unusable_channels)
# Step 1: 1Hz high pass filtering
if len(self.prep_params["line_freqs"]) != 0:
self.EEG_new = removeTrend(
self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict
@property
def current_noisy_info(self):
post_ref = self.noisy_info["post-reference"]
post_interp = self.noisy_info["post-interpolation"]
return post_interp if post_interp else post_ref

@property
def remaining_bad_channels(self):
post_ref = self.bad_channels["post-reference"]
post_interp = self.bad_channels["post-interpolation"]
return post_interp if post_interp else post_ref

@property
def current_reference_signal(self):
post_ref = self.robust_reference_signal
post_interp = self._interpolated_reference_signal
return post_interp if post_interp else post_ref

def get_raw(self, stage=None):
"""Retrieve the full recording data at a given stage of the pipeline.

Valid pipeline stages include 'unprocessed' (the raw data prior to running
the pipeline), 'filtered' (the data following adaptive line noise
removal), 'post-reference' (the data after robust referencing, prior to any
bad channel interpolation), and 'post-interpolation' (the data after robust
referencing and bad channel interpolation).

Parameters
----------
stage : str, optional
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd do something like this:

Suggested change
stage : str, optional
stage : {"unprocessed", "filtered", "post-reference", "post-interpolation"}, optional

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That causes the line to go beyond 88 characters, is line wrap for argument types something that's supported by Numpy docstyle?

The stage of the pipeline for which the full data will be retrieved. If
not specified, the current state of the data will be retrieved.

Returns
-------
full_raw: mne.io.Raw
An MNE Raw object containing the EEG data for the given stage of the
pipeline, along with any non-EEG channels that were present in the
original input data.

"""
interpolated = self.interpolated_channels is not None
stages = {
"unprocessed": self.EEG_raw,
"filtered": self.EEG_filtered,
"post-reference": self.EEG_post_reference,
"post-interpolation": self.raw_eeg._data if interpolated else None,
}
if stage is not None and stage.lower() not in stages.keys():
raise ValueError(
"'{stage}' is not a valid pipeline stage. Valid stages are "
"'unprocessed', 'filtered', 'post-reference', and 'post-interpolation'."
)

# Step 2: Removing line noise
linenoise = self.prep_params["line_freqs"]
if self.filter_kwargs is None:
self.EEG_clean = mne.filter.notch_filter(
self.EEG_new,
Fs=self.sfreq,
freqs=linenoise,
method="spectrum_fit",
mt_bandwidth=2,
p_value=0.01,
filter_length="10s",
)
else:
self.EEG_clean = mne.filter.notch_filter(
self.EEG_new,
Fs=self.sfreq,
freqs=linenoise,
**self.filter_kwargs,
eeg_data = self.raw_eeg._data # Default to most recent stage of pipeline
if stage:
eeg_data = stages[stage.lower()]
if not eeg_data:
raise ValueError(
"Could not retrieve {stage} data, as that stage of the pipeline "
"has not yet been performed."
Comment on lines +251 to +252
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Could not retrieve {stage} data, as that stage of the pipeline "
"has not yet been performed."
f"Could not retrieve {stage} data, as that stage of the pipeline "
"has not yet been performed."

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, nice catch!

)
full_raw = self.raw_eeg.copy()
full_raw._data = eeg_data
if self.raw_non_eeg is not None:
full_raw.add_channels([self.raw_non_eeg])

# Add Trend back
self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean
self.raw_eeg._data = self.EEG
return full_raw

# Step 3: Referencing
reference = Reference(
def remove_line_noise(self, line_freqs):
"""Remove line noise from all EEG channels using multi-taper decomposition.

This filtering method attempts to isolate and remove line noise from the
signal while preserving unrelated background signal in the same frequency
ranges. This is done to minimize distortions in the power-spectral density
curves due to line noise removal.

Parameters
----------
line_freqs: {np.ndarray, list}
A list of the frequencies (in Hz) at which line noise should be removed
(e.g., ``np.arange(60, sfreq / 2, 60)`` for a recording with a powerline
noise of 60 Hz).

"""
# Define default settings for filter and apply any kwargs overrides
settings = {"mt_bandwidth": 2, "p_value": 0.01, "filter_length": "10s"}
if isinstance(self.filter_kwargs, dict):
settings.update(self.filter_kwargs)

# Remove slow drifts from the recording prior to filtering
eeg_detrended = removeTrend(
self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict
)

# Remove line noise and add the removed slow drifts back
eeg_cleaned = mne.filter.notch_filter(
eeg_detrended,
Fs=self.sfreq,
freqs=line_freqs,
method="spectrum_fit",
**settings,
# Add support for parallel jobs if joblib installed?
)
self.EEG_filtered = (self.EEG_raw - eeg_detrended) + eeg_cleaned
self.raw_eeg._data = self.EEG_filtered

def robust_reference(self, max_iterations=4, interpolate_bads=True):
"""Perform robust referencing on the EEG signal and detect bad channels.

This method uses an iterative approach to estimate a robust average
reference signal free of contamination from bad channels, as detected
automatically using the methods of :class:`~pyprep.NoisyChannels`. Once
estimated, the robust average reference is applied to the data and bad
channel detection is re-run to flag any noisy or unusable channels
post-reference.

By default, this method will also interpolate the signals of any channels
detected as bad following robust referencing, re-reference the data
accordingly, and re-detect any remaining bad channels.

Parameters
----------
max_iterations : int, optional
The maximum number of iterations of noisy channel removal to perform
during robust referencing. Defaults to ``4``.
interpolate_bads : bool, optional
Whether or not any remaining bad channels following robust referencing
should be interpolated. Defaults to ``True``.

"""
# Perform robust referencing on the signal
ref = Reference(
self.raw_eeg,
self.prep_params,
random_state=self.random_state,
matlab_strict=self.matlab_strict,
**self.ransac_settings,
)
reference.perform_reference(self.prep_params["max_iterations"])
self.raw_eeg = reference.raw
self.noisy_channels_original = reference.noisy_channels_original
self.noisy_channels_before_interpolation = (
reference.noisy_channels_before_interpolation
)
self.noisy_channels_after_interpolation = (
reference.noisy_channels_after_interpolation
)
self.bad_before_interpolation = reference.bad_before_interpolation
self.EEG_before_interpolation = reference.EEG_before_interpolation
self.reference_before_interpolation = reference.reference_signal
self.reference_after_interpolation = reference.reference_signal_new
self.interpolated_channels = reference.interpolated_channels
self.still_noisy_channels = reference.still_noisy_channels
ref.perform_reference(max_iterations, interpolate_bads)

self.raw_eeg = ref.raw
self.EEG_post_reference = ref.EEG_before_interpolation
self.robust_reference_signal = ref.reference_signal
self._interpolated_reference_signal = ref.reference_signal_new

self.noisy_info["original"] = ref.noisy_channels_original
self.noisy_info["post-reference"] = ref.noisy_channels_before_interpolation
self.noisy_info["post-interpolation"] = ref.noisy_channels_after_interpolation

self.bad_channels["original"] = ref.noisy_channels_original["bad_all"]
self.bad_channels["post-reference"] = ref.bad_before_interpolation
self.bad_channels["post-interpolation"] = ref.still_noisy_channels
self.interpolated_channels = ref.interpolated_channels

def fit(self):
"""Run the whole PREP pipeline."""
# Step 1: Adaptive line noise removal
if len(self.prep_params["line_freqs"]) != 0:
self.remove_line_noise(self.prep_params["line_freqs"])

# Step 2: Robust Referencing
self.robust_reference(self.prep_params["max_iterations"])

return self
Loading