Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make noisy channel exclusion during Reference compatible with MATLAB PREP #93

Merged
merged 16 commits into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,46 +125,61 @@ def _get_filtered_data(self):

return EEG_filt

def get_bads(self, verbose=False):
"""Get a list of all channels currently flagged as bad.
def get_bads(self, verbose=False, as_dict=False):
"""Get the names of all channels currently flagged as bad.

Note that this method does not perform any bad channel detection itself,
and only reports channels already detected as bad by other methods.

Parameters
----------
verbose : bool
verbose : bool, optional
If ``True``, a summary of the channels currently flagged as by bad per
category is printed. Defaults to ``False``.
as_dict: bool, optional
If ``True``, this method will return a dict of the channels currently
flagged as bad by each individual bad channel type. If ``False``, this
method will return a list of all unique bad channels detected so far.
Defaults to ``False``.

Returns
-------
bads : list
THe names of all bad channels detected by any method so far.
bads : list or dict
The names of all bad channels detected so far, either as a combined
list or a dict indicating the channels flagged bad by each type.

"""
bads = {
"n/a": self.bad_by_nan,
"flat": self.bad_by_flat,
"deviation": self.bad_by_deviation,
"hf noise": self.bad_by_hf_noise,
"correl": self.bad_by_correlation,
"SNR": self.bad_by_SNR,
"dropout": self.bad_by_dropout,
"RANSAC": self.bad_by_ransac,
"bad_by_nan": self.bad_by_nan,
"bad_by_flat": self.bad_by_flat,
"bad_by_deviation": self.bad_by_deviation,
"bad_by_hf_noise": self.bad_by_hf_noise,
"bad_by_correlation": self.bad_by_correlation,
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
}

all_bads = set()
for bad_chs in bads.values():
all_bads.update(bad_chs)

name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
if verbose:
out = f"Found {len(all_bads)} uniquely bad channels:\n"
for bad_type, bad_chs in bads.items():
bad_type = bad_type.replace("bad_by_", "")
if bad_type in name_map.keys():
bad_type = name_map[bad_type]
out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n"
print(out)

return list(all_bads)
if as_dict:
bads["bad_all"] = list(all_bads)
else:
bads = list(all_bads)

return bads

def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
"""Call all the functions to detect bad channels.
Expand Down
99 changes: 31 additions & 68 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,7 @@ def perform_reference(self):
# Record Noisy channels and EEG before interpolation
self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
self.EEG_before_interpolation = self.EEG.copy()
self.noisy_channels_before_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
self._extra_info["interpolated"] = noisy_detector._extra_info

bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
Expand All @@ -170,17 +160,7 @@ def perform_reference(self):
noisy_detector.find_all_bads(**self.ransac_settings)
self.still_noisy_channels = noisy_detector.get_bads()
self.raw.info["bads"] = self.still_noisy_channels
self.noisy_channels_after_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True)
self._extra_info["remaining_bad"] = noisy_detector._extra_info

return self
Expand Down Expand Up @@ -213,17 +193,7 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels_original = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
self._extra_info["initial_bad"] = noisy_detector._extra_info
logger.info("Bad channels: {}".format(self.noisy_channels_original))

Expand All @@ -235,16 +205,16 @@ def robust_reference(self):
reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

# Initialize channels to permanently flag as bad during referencing
self.noisy_channels = {
noisy = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": [],
"bad_by_hf_noise": [],
"bad_by_correlation": [],
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_SNR": [],
"bad_by_dropout": [],
"bad_by_ransac": [],
"bad_all": self.unusable_channels,
"bad_all": [],
}

# Get initial estimate of the reference by the specified method
Expand All @@ -260,7 +230,7 @@ def robust_reference(self):
# Remove reference from signal, iteratively interpolating bad channels
raw_tmp = raw.copy()
iterations = 0
noisy_channels_old = []
previous_bads = set()
max_iteration_num = 4
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved

while True:
Expand All @@ -272,51 +242,43 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
# Detrend applied at the beginning of the function.

# Detect all currently bad channels
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels["bad_by_nan"] = _union(
self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan
)
self.noisy_channels["bad_by_flat"] = _union(
self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat
)
self.noisy_channels["bad_by_deviation"] = _union(
self.noisy_channels["bad_by_deviation"], noisy_detector.bad_by_deviation
)
self.noisy_channels["bad_by_hf_noise"] = _union(
self.noisy_channels["bad_by_hf_noise"], noisy_detector.bad_by_hf_noise
)
self.noisy_channels["bad_by_correlation"] = _union(
self.noisy_channels["bad_by_correlation"],
noisy_detector.bad_by_correlation,
)
self.noisy_channels["bad_by_ransac"] = _union(
self.noisy_channels["bad_by_ransac"], noisy_detector.bad_by_ransac
)
self.noisy_channels["bad_all"] = _union(
self.noisy_channels["bad_all"], noisy_detector.get_bads()
)
logger.info("Bad channels: {}".format(self.noisy_channels))
noisy_new = noisy_detector.get_bads(as_dict=True)

# Specify bad channel types to ignore when updating noisy channels
# NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
ignore = ["bad_by_SNR", "bad_all"]
if self.matlab_strict:
ignore += ["bad_by_dropout"]

# Update set of all noisy channels detected so far with any new ones
bad_chans = set()
for bad_type in noisy_new.keys():
noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
if bad_type not in ignore:
bad_chans.update(noisy[bad_type])
noisy["bad_all"] = list(bad_chans)
logger.info("Bad channels: {}".format(noisy))

if (
iterations > 1
and (
not self.noisy_channels["bad_all"]
or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old)
)
and (len(bad_chans) == 0 or bad_chans == previous_bads)
or iterations > max_iteration_num
):
break
noisy_channels_old = self.noisy_channels["bad_all"].copy()
previous_bads = bad_chans.copy()

if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
if raw_tmp.info["nchan"] - len(bad_chans) < 2:
raise ValueError(
"RobustReference:TooManyBad "
"Could not perform a robust reference -- not enough good channels"
)

if self.noisy_channels["bad_all"]:
if len(bad_chans) > 0:
raw_tmp._data = signal * 1e-6
raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
raw_tmp.info["bads"] = list(bad_chans)
raw_tmp.interpolate_bads()
signal_tmp = raw_tmp.get_data() * 1e6
else:
Expand All @@ -332,6 +294,7 @@ def robust_reference(self):
logger.info("Iterations: {}".format(iterations))

logger.info("Robust reference done")
self.noisy_channels = noisy
return self.noisy_channels, self.reference_signal

@staticmethod
Expand Down