Skip to content

Commit

Permalink
Reference interpolates instead of deletes channels manually marked as…
Browse files Browse the repository at this point in the history
… bad, closes #146 (#156)

* allows manually marked bad channels to be excluded from reference / interpolated, without being removed (which had been causing an error)

* adds new contributor to CITATION.cff

* adds new contributor ORCID

* Update authors.rst

* Update changelog.rst

* fixes some formatting issues flagged by CI

* adds manual bad channels to find_noisy_channels

* changes name of manual bads to match find_noisy_channels (which requires a specific format to work properly)

* adds test coverage

* fixes bug and formatting in test

* format
  • Loading branch information
john-veillette authored Sep 17, 2024
1 parent 06fb37b commit 89e6fc9
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ authors:
- given-names: Ayush
family-names: Agarwal
affiliation: 'Techno India University, India'
- given-names: John
family-names: Veillette
affiliation: 'Department of Psychology, University of Chicago, Chicago, IL, USA'
orcid: 'https://orcid.org/0000-0002-0332-4372'
type: software
repository-code: 'https://github.com/sappelhoff/pyprep'
license: MIT
Expand Down
1 change: 1 addition & 0 deletions docs/authors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
.. _Stefan Appelhoff: https://stefanappelhoff.com/
.. _Victor Xiang: https://github.com/Nick3151
.. _Yorguin Mantilla: https://github.com/yjmantilla
.. _John Veillette: https://github.com/john-veillette
2 changes: 1 addition & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Changelog

Bug
~~~
- nothing yet
- :class:`~pyprep.Reference` now keeps and interpolates channels channels manually marked as bad before PREP, by `John Veillette`_ (:gh:`146`)

Code health
~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)

raw.load_data()
self.raw_mne = raw.copy()
self.bad_by_manual = raw.info["bads"]
self.raw_mne.pick_types(eeg=True)
self.sample_rate = raw.info["sfreq"]
if do_detrend:
Expand Down Expand Up @@ -162,6 +163,7 @@ def get_bads(self, verbose=False, as_dict=False):
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
"bad_by_manual": self.bad_by_manual,
}

all_bads = set()
Expand Down
10 changes: 7 additions & 3 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
raw.load_data()
self.raw = raw.copy()
self.ch_names = self.raw.ch_names
self.raw.pick_types(eeg=True, eog=False, meg=False)
self.raw.pick_types(eeg=True, eog=False, meg=False, exclude=[])
self.ch_names_eeg = self.raw.ch_names
self.EEG = self.raw.get_data()
self.reference_channels = params["ref_chs"]
Expand All @@ -97,6 +97,7 @@ def __init__(
self.random_state = check_random_state(random_state)
self._extra_info = {}
self.matlab_strict = matlab_strict
self.bads_manual = raw.info["bads"]

def perform_reference(self, max_iterations=4):
"""Estimate the true signal mean and interpolate bad channels.
Expand Down Expand Up @@ -149,6 +150,7 @@ def perform_reference(self, max_iterations=4):
self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
self.EEG_before_interpolation = self.EEG.copy()
self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
self.noisy_channels_before_interpolation["bad_by_manual"] = self.bads_manual
self._extra_info["interpolated"] = noisy_detector._extra_info

bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
Expand Down Expand Up @@ -223,7 +225,7 @@ def robust_reference(self, max_iterations=4):
# Determine channels to use/exclude from initial reference estimation
self.unusable_channels = _union(
noisy_detector.bad_by_nan + noisy_detector.bad_by_flat,
noisy_detector.bad_by_SNR,
noisy_detector.bad_by_SNR + self.bads_manual,
)
reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

Expand All @@ -237,6 +239,7 @@ def robust_reference(self, max_iterations=4):
"bad_by_SNR": [],
"bad_by_dropout": [],
"bad_by_ransac": [],
"bad_by_manual": self.bads_manual,
"bad_all": [],
}

Expand Down Expand Up @@ -282,7 +285,8 @@ def robust_reference(self, max_iterations=4):
noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
if bad_type not in ignore:
bad_chans.update(noisy[bad_type])
noisy["bad_all"] = list(bad_chans)
noisy["bad_by_manual"] = self.bads_manual
noisy["bad_all"] = list(bad_chans) + self.bads_manual
logger.info(f"Bad channels: {noisy}")

if (
Expand Down
14 changes: 14 additions & 0 deletions tests/test_find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ def test_bad_by_nan(raw_tmp):
assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]]


def test_bad_by_manual(raw_tmp):
"""Test the detection of channels marked bad a priori."""
n_chans = raw_tmp.get_data().shape[0]
nan_idx = int(rng.integers(0, n_chans, 1)[0])
raw_tmp._data[nan_idx, 3] = np.nan
raw_tmp.info["bads"] = [raw_tmp.ch_names[0]]

# Test record of a priori bad channels on NoisyChannels init
nd = NoisyChannels(raw_tmp, do_detrend=False)
assert nd.bad_by_manual == [raw_tmp.ch_names[0]]
assert raw_tmp.ch_names[0] in nd.get_bads(as_dict=False)
raw_tmp.info["bads"] = []


def test_bad_by_flat(raw_tmp):
"""Test the detection of channels with flat or very weak signals."""
# Make the signal for a random channel extremely weak
Expand Down
1 change: 1 addition & 0 deletions tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_basic_input(raw, montage):
reference = Reference(raw_tmp, params, ransac=False)
reference.perform_reference()
assert isinstance(reference.noisy_channels, dict)
assert isinstance(reference.noisy_channels["bad_by_manual"], list)
assert isinstance(reference.noisy_channels_original, dict)
assert isinstance(reference.bad_before_interpolation, list)
assert isinstance(reference.reference_signal, np.ndarray)
Expand Down

0 comments on commit 89e6fc9

Please sign in to comment.