Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reference interpolates instead of deletes channels manually marked as bad, closes #146 #156

Merged
merged 12 commits into from
Sep 17, 2024
4 changes: 4 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ authors:
- given-names: Ayush
family-names: Agarwal
affiliation: 'Techno India University, India'
- given-names: John
family-names: Veillette
affiliation: 'Department of Psychology, University of Chicago, Chicago, IL, USA'
orcid: 'https://orcid.org/0000-0002-0332-4372'
type: software
repository-code: 'https://github.com/sappelhoff/pyprep'
license: MIT
Expand Down
1 change: 1 addition & 0 deletions docs/authors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
.. _Stefan Appelhoff: https://stefanappelhoff.com/
.. _Victor Xiang: https://github.com/Nick3151
.. _Yorguin Mantilla: https://github.com/yjmantilla
.. _John Veillette: https://github.com/john-veillette
2 changes: 1 addition & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Changelog

Bug
~~~
- nothing yet
- :class:`~pyprep.Reference` now keeps and interpolates channels channels manually marked as bad before PREP, by `John Veillette`_ (:gh:`146`)

Code health
~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)

raw.load_data()
self.raw_mne = raw.copy()
self.bad_by_manual = raw.info["bads"]
self.raw_mne.pick_types(eeg=True)
self.sample_rate = raw.info["sfreq"]
if do_detrend:
Expand Down Expand Up @@ -162,6 +163,7 @@ def get_bads(self, verbose=False, as_dict=False):
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
"bad_by_manual": self.bad_by_manual,
}

all_bads = set()
Expand Down
10 changes: 7 additions & 3 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
raw.load_data()
self.raw = raw.copy()
self.ch_names = self.raw.ch_names
self.raw.pick_types(eeg=True, eog=False, meg=False)
self.raw.pick_types(eeg=True, eog=False, meg=False, exclude=[])
self.ch_names_eeg = self.raw.ch_names
self.EEG = self.raw.get_data()
self.reference_channels = params["ref_chs"]
Expand All @@ -97,6 +97,7 @@ def __init__(
self.random_state = check_random_state(random_state)
self._extra_info = {}
self.matlab_strict = matlab_strict
self.bads_manual = raw.info["bads"]

def perform_reference(self, max_iterations=4):
"""Estimate the true signal mean and interpolate bad channels.
Expand Down Expand Up @@ -149,6 +150,7 @@ def perform_reference(self, max_iterations=4):
self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
self.EEG_before_interpolation = self.EEG.copy()
self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
self.noisy_channels_before_interpolation["bad_by_manual"] = self.bads_manual
self._extra_info["interpolated"] = noisy_detector._extra_info

bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
Expand Down Expand Up @@ -223,7 +225,7 @@ def robust_reference(self, max_iterations=4):
# Determine channels to use/exclude from initial reference estimation
self.unusable_channels = _union(
noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
noisy_detector.bad_by_SNR,
noisy_detector.bad_by_SNR + self.bads_manual,
)
reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

Expand All @@ -237,6 +239,7 @@ def robust_reference(self, max_iterations=4):
"bad_by_SNR": [],
"bad_by_dropout": [],
"bad_by_ransac": [],
"bad_by_manual": self.bads_manual,
"bad_all": [],
}

Expand Down Expand Up @@ -282,7 +285,8 @@ def robust_reference(self, max_iterations=4):
noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
if bad_type not in ignore:
bad_chans.update(noisy[bad_type])
noisy["bad_all"] = list(bad_chans)
noisy["bad_by_manual"] = self.bads_manual
noisy["bad_all"] = list(bad_chans) + self.bads_manual
logger.info(f"Bad channels: {noisy}")

if (
Expand Down
14 changes: 14 additions & 0 deletions tests/test_find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ def test_bad_by_nan(raw_tmp):
assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]]


def test_bad_by_manual(raw_tmp):
"""Test the detection of channels marked bad a priori."""
n_chans = raw_tmp.get_data().shape[0]
nan_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[nan_idx, 3] = np.nan
raw_tmp.info["bads"] = [raw_tmp.ch_names[0]]

# Test record of a priori bad channels on NoisyChannels init
nd = NoisyChannels(raw_tmp, do_detrend=False)
assert nd.bad_by_manual == [raw_tmp.ch_names[0]]
assert raw_tmp.ch_names[0] in nd.get_bads(as_dict=False)
raw_tmp.info["bads"] = []


def test_bad_by_flat(raw_tmp):
"""Test the detection of channels with flat or very weak signals."""
# Make the signal for a random channel extremely weak
Expand Down
1 change: 1 addition & 0 deletions tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_basic_input(raw, montage):
reference = Reference(raw_tmp, params, ransac=False)
reference.perform_reference()
assert isinstance(reference.noisy_channels, dict)
assert isinstance(reference.noisy_channels["bad_by_manual"], list)
assert isinstance(reference.noisy_channels_original, dict)
assert isinstance(reference.bad_before_interpolation, list)
assert isinstance(reference.reference_signal, np.ndarray)
Expand Down
Loading