From 1fc9fcc09b0134f78aebe012c011a1c5b335a71c Mon Sep 17 00:00:00 2001 From: Nabil-AL Date: Sun, 14 Apr 2024 15:08:15 +0200 Subject: [PATCH] ADD prototype of bad_by_PSD() method --- pyprep/find_noisy_channels.py | 44 ++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index d849fd2..ba48cc6 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -4,7 +4,7 @@ import mne import numpy as np from mne.utils import check_random_state, logger -from scipy import signal +from scipy import signal, stats from pyprep.ransac import find_bad_by_ransac from pyprep.removeTrend import removeTrend @@ -70,6 +70,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False) "bad_by_hf_noise": {}, "bad_by_correlation": {}, "bad_by_dropout": {}, + "bad_by_psd": {}, "bad_by_ransac": {}, } @@ -84,6 +85,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False) self.bad_by_correlation = [] self.bad_by_SNR = [] self.bad_by_dropout = [] + self.bad_by_psd = [] self.bad_by_ransac = [] # Get original EEG channel names, channel count & samples @@ -486,6 +488,46 @@ def find_bad_by_SNR(self): # Flag channels bad by both HF noise and low correlation as bad by low SNR self.bad_by_SNR = list(bad_by_corr.intersection(bad_by_hf)) + def find_bad_by_PSD(self, zscore_threshold=3.0): + """ + Detect channels with abnormally high or low overall power spectral density + (PSD) values. + + A channel is considered "bad-by-psd" if its psd value deviates + considerably from the median channel psd, as calculated using a + Z-scoring method and the given z-score threshold. + PSD calculation is done using the Welch method. + Uses the Welch method for PSD calculation + + Parameters + ---------- + zscore_threshold : float, optional + The minimum noisiness z-score of a channel for it to be considered + bad-by-psd. Defaults to ``3.0``. + """ + if self.EEGFiltered is None: + self.EEGFiltered = self._get_filtered_data() + psd = self.EEGFiltered.compute_psd(method='welch', fmin=1, fmax=50) + log_psd = 10 * np.log10(psd.get_data()) + median_channel_psd = np.median(log_psd, axis=0) + + # # Calculate robust Z-scores for the channel amplitudes + psd_zscore = np.zeros(self.n_chans_original) + psd_zscore[self.usable_idx] = stats.zscore(np.sum(log_psd - median_channel_psd, axis=1)) + + # Flag channels with unusually high or low PSD values compared to the median channel + psd_channel_mask = np.isnan(psd_zscore) | (psd_zscore > zscore_threshold) + abnormal_psd_channels = self.ch_names_original[psd_channel_mask] + + # Update names of bad channels by abnormal PSD & save additional info + self.bad_by_psd = abnormal_psd_channels.tolist() + self._extra_info["bad_by_psd"].update( + { + "median_channel_psd": median_channel_psd, + "psd_zscore": psd_zscore, + } + ) + def find_bad_by_ransac( self, n_samples=50,