diff --git a/tests/conftest.py b/tests/conftest.py index 7983483f..626c0e68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,15 +14,55 @@ def montage(): @pytest.fixture(scope="session") def raw(): - """Fixture for physionet EEG subject 4, dataset 1.""" + """Return an `mne.io.Raw` object for use with unit tests. + + This fixture downloads and reads in subject 4, run 1 from the Physionet + BCI2000 (eegbci) open dataset. This recording is quite noisy and is thus a + good candidate for testing the PREP pipeline. + + File attributes: + - Channels: 64 EEG + - Sample rate: 160 Hz + - Duration: 61 seconds + + This is only run once per session to save time downloading. + + """ mne.set_log_level("WARNING") - # load in subject 1, run 1 dataset + + # Download and read S004R01.edf from the BCI2000 dataset edf_fpath = eegbci.load_data(4, 1, update_path=True)[0] + raw = mne.io.read_raw_edf(edf_fpath, preload=True) + eegbci.standardize(raw) # Fix non-standard channel names + + return raw + + +@pytest.fixture(scope="session") +def raw_clean(montage): + """Return an `mne.io.Raw` object with no bad channels for use with tests. + + This fixture downloads and reads in subject 30, run 2 from the Physionet + BCI2000 (eegbci) open dataset, which contains no bad channels on an initial + pass of :class:`pyprep.NoisyChannels`. Intended for use with tests where + channels are made artificially bad. + + File attributes: + - Channels: 64 EEG + - Sample rate: 160 Hz + - Duration: 61 seconds + + This is only run once per session to save time downloading. + + """ + mne.set_log_level("WARNING") - # using sample EEG data (https://physionet.org/content/eegmmidb/1.0.0/) + # Download and read S030R02.edf from the BCI2000 dataset + edf_fpath = eegbci.load_data(30, 2, update_path=True)[0] raw = mne.io.read_raw_edf(edf_fpath, preload=True) + eegbci.standardize(raw) # Fix non-standard channel names - # The eegbci data has non-standard channel names. 
We need to rename them: - eegbci.standardize(raw) + # Set a montage for use with RANSAC + raw.set_montage(montage) return raw diff --git a/tests/test_find_noisy_channels.py b/tests/test_find_noisy_channels.py index 8d45ae3a..daf735ef 100644 --- a/tests/test_find_noisy_channels.py +++ b/tests/test_find_noisy_channels.py @@ -1,155 +1,212 @@ """Test the find_noisy_channels module.""" import numpy as np +from numpy.random import RandomState import pytest from pyprep.find_noisy_channels import NoisyChannels from pyprep.removeTrend import removeTrend -@pytest.mark.usefixtures("raw", "montage") -def test_findnoisychannels(raw, montage): - """Test find noisy channels.""" - # Set a random state for the test - rng = np.random.RandomState(30) +# Set a fixed random seed for reproducible test results - raw.set_montage(montage) - nd = NoisyChannels(raw, random_state=rng) - nd.find_all_bads(ransac=True) - bads = nd.get_bads() - iterations = ( - 10 # remove any noisy channels by interpolating the bads for 10 iterations +RNG = RandomState(30) + + +# Define some fixtures and utility functions for use across multiple tests + +@pytest.fixture(scope="session") +def raw_clean_detrend(raw_clean): + """Return a pre-detrended `mne.io.Raw` object with no bad channels. + + Based on the data from the `raw_clean` fixture, which uses the data for + subject 30, run 2 from the Physionet BCI2000 dataset. + + This is only run once per session to save time. 
+ + """ + raw_clean_detrended = raw_clean.copy() + raw_clean_detrended._data = removeTrend( + raw_clean.get_data(), + raw_clean.info["sfreq"] ) - for iter in range(0, iterations): - if len(bads) == 0: - continue - raw.info["bads"] = bads - raw.interpolate_bads() - nd = NoisyChannels(raw, random_state=rng) - nd.find_all_bads(ransac=True) - bads = nd.get_bads() - - # make sure no bad channels exist in the data - raw.drop_channels(ch_names=bads) - - # Test for NaN and flat channels - raw_tmp = raw.copy() - m, n = raw_tmp._data.shape - # Insert a nan value for a random channel and make another random channel - # completely flat (ones) - idxs = rng.choice(np.arange(m), size=2, replace=False) - rand_chn_idx1 = idxs[0] - rand_chn_idx2 = idxs[1] - rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1] - rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2] - raw_tmp._data[rand_chn_idx1, n - 1] = np.nan - raw_tmp._data[rand_chn_idx2, :] = np.ones(n) - nd = NoisyChannels(raw_tmp, random_state=rng) + return raw_clean_detrended + + +@pytest.fixture +def raw_tmp(raw_clean_detrend): + """Return an unmodified copy of the `raw_clean_detrend` fixture. + + This is run once per NoisyChannels test, to keep any modifications to + `raw_tmp` during a test from affecting `raw_tmp` in any others. 
+ + """ + raw_tmp = raw_clean_detrend.copy() + return raw_tmp + + +def _generate_signal(fmin, fmax, timepoints, fcount=1): + """Generate an EEG signal from one or more sine waves in a frequency range.""" + signal = np.zeros_like(timepoints) + for freq in RNG.randint(fmin, fmax+1, fcount): + signal += np.sin(2 * np.pi * timepoints * freq) + return signal * 1e-6 + + +# Run unit tests for each bad channel type detected by NoisyChannels + +def test_bad_by_nan(raw_tmp): + """Test the detection of channels containing any NaN values.""" + # Insert a NaN value into a random channel + n_chans = raw_tmp._data.shape[0] + nan_idx = int(RNG.randint(0, n_chans, 1)) + raw_tmp._data[nan_idx, 3] = np.nan + + # Test automatic detection of NaN channels on NoisyChannels init + nd = NoisyChannels(raw_tmp, do_detrend=False) + assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]] + + # Test manual re-running of NaN channel detection nd.find_bad_by_nan_flat() - assert nd.bad_by_nan == [rand_chn_lab1] - assert nd.bad_by_flat == [rand_chn_lab2] - - # Test for high and low deviations in EEG data - raw_tmp = raw.copy() - m, n = raw_tmp._data.shape - # Now insert one random channel with very low deviations - rand_chn_idx = int(rng.randint(0, m, 1)) - rand_chn_lab = raw_tmp.ch_names[rand_chn_idx] - raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10 - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_bad_by_deviation() - assert rand_chn_lab in nd.bad_by_deviation - # Inserting one random channel with a high deviation - raw_tmp = raw.copy() - rand_chn_idx = int(rng.randint(0, m, 1)) - rand_chn_lab = raw_tmp.ch_names[rand_chn_idx] - arbitrary_scaling = 5 - raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling - nd = NoisyChannels(raw_tmp, random_state=rng) + assert nd.bad_by_nan == [raw_tmp.ch_names[nan_idx]] + + +def test_bad_by_flat(raw_tmp): + """Test the detection of channels with flat or very weak signals.""" + # Make the signal for a random channel extremely weak + 
n_chans = raw_tmp._data.shape[0] + flat_idx = int(RNG.randint(0, n_chans, 1)) + raw_tmp._data[flat_idx, :] = raw_tmp._data[flat_idx, :] * 1e-12 + + # Test automatic detection of flat channels on NoisyChannels init + nd = NoisyChannels(raw_tmp, do_detrend=False) + assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]] + + # Test manual re-running of flat channel detection + nd.find_bad_by_nan_flat() + assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]] + + # Test detection when channel is completely flat + raw_tmp._data[flat_idx, :] = 0 + nd = NoisyChannels(raw_tmp, do_detrend=False) + assert nd.bad_by_flat == [raw_tmp.ch_names[flat_idx]] + + +def test_bad_by_deviation(raw_tmp): + """Test detection of channels with relatively high or low amplitudes.""" + # Set scaling factors for high and low deviation test channels + low_dev_factor = 0.1 + high_dev_factor = 4.0 + + # Make the signal for a random channel have a very high amplitude + n_chans = raw_tmp._data.shape[0] + high_dev_idx = int(RNG.randint(0, n_chans, 1)) + raw_tmp._data[high_dev_idx, :] *= high_dev_factor + + # Test detection of abnormally high-amplitude channels + nd = NoisyChannels(raw_tmp, do_detrend=False) nd.find_bad_by_deviation() - assert rand_chn_lab in nd.bad_by_deviation - - # Test for correlation between EEG channels - raw_tmp = raw.copy() - m, n = raw_tmp._data.shape - rand_chn_idx = int(rng.randint(0, m, 1)) - rand_chn_lab = raw_tmp.ch_names[rand_chn_idx] - # Use cosine instead of sine to create a signal - low = 10 - high = 30 - n_freq = 5 - signal = np.zeros((1, n)) - for freq_i in range(n_freq): - freq = rng.randint(low, high, n) - signal[0, :] += np.cos(2 * np.pi * raw.times * freq) - raw_tmp._data[rand_chn_idx, :] = signal * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_bad_by_correlation() - assert rand_chn_lab in nd.bad_by_correlation - bad_by_correlation_orig = nd.bad_by_correlation # save for dropout tests - - # Test for channels with signal dropouts (reuse data from 
correlation tests) - dropout_idx = rand_chn_idx - 1 if rand_chn_idx > 0 else 1 - # Make 2nd and 4th quarters of the dropout channel completely flat - raw_tmp._data[dropout_idx, :int(n/4)] = 0 - raw_tmp._data[dropout_idx, int(3*n/4):] = 0 - # Run correlation and dropout detection on data - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_bad_by_correlation() # also does dropout detection - # Test if dropout channel detected correctly - assert raw_tmp.ch_names[dropout_idx] in nd.bad_by_dropout - # Test if correlations still detected correctly - bad_orig_plus_dropout = bad_by_correlation_orig + nd.bad_by_dropout - same_bads = set(nd.bad_by_correlation) == set(bad_by_correlation_orig) - same_plus_dropout = set(nd.bad_by_correlation) == set(bad_orig_plus_dropout) - assert same_bads or same_plus_dropout - - # Test for high freq noise detection - raw_tmp = raw.copy() - m, n = raw_tmp._data.shape - rand_chn_idx = int(rng.randint(0, m, 1)) - rand_chn_lab = raw_tmp.ch_names[rand_chn_idx] - # Use freqs between 90 and 100 Hz to insert hf noise - signal = np.zeros((1, n)) - for freq_i in range(n_freq): - freq = rng.randint(90, 100, n) - signal[0, :] += np.sin(2 * np.pi * raw.times * freq) - raw_tmp._data[rand_chn_idx, :] = signal * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng) + assert nd.bad_by_deviation == [raw_tmp.ch_names[high_dev_idx]] + + # Make the signal for a different channel have a very low amplitude + low_dev_idx = (high_dev_idx - 1) if high_dev_idx > 0 else 1 + raw_tmp._data[low_dev_idx, :] *= low_dev_factor + + # Test detection of abnormally low-amplitude channels + # NOTE: The default z-score threshold (5.0) is too strict to allow detection + # of abnormally low-amplitude channels in some datasets. Using a relaxed Z + # threshold of 3.29 (p < 0.001, two-tailed) until a better solution is found. 
+ nd = NoisyChannels(raw_tmp, do_detrend=False) + nd.find_bad_by_deviation(deviation_threshold=3.29) + bad_by_dev_idx = [low_dev_idx, high_dev_idx] + assert nd.bad_by_deviation == [raw_tmp.ch_names[i] for i in bad_by_dev_idx] + + +def test_bad_by_hf_noise(raw_tmp): + """Test detection of channels with high-frequency noise.""" + # Add some noise between 70 & 80 Hz to the signal of a random channel + n_chans = raw_tmp._data.shape[0] + hf_noise_idx = int(RNG.randint(0, n_chans, 1)) + hf_noise = _generate_signal(70, 80, raw_tmp.times, 5) * 10 + raw_tmp._data[hf_noise_idx, :] += hf_noise + + # Test detection of channels with high-frequency noise + nd = NoisyChannels(raw_tmp, do_detrend=False) nd.find_bad_by_hfnoise() - assert rand_chn_lab in nd.bad_by_hf_noise + assert nd.bad_by_hf_noise == [raw_tmp.ch_names[hf_noise_idx]] - # Test for high freq noise detection when sample rate < 100 Hz - raw_tmp.resample(80) # downsample to 80 Hz - nd = NoisyChannels(raw_tmp, random_state=rng) + # Test lack of high-frequency noise detection when sample rate < 100 Hz + raw_tmp.resample(80) # downsample from 160 Hz to 80 Hz + nd = NoisyChannels(raw_tmp, do_detrend=False) nd.find_bad_by_hfnoise() assert len(nd.bad_by_hf_noise) == 0 + assert nd._extra_info['bad_by_hf_noise']['median_channel_noisiness'] == 0 + assert nd._extra_info['bad_by_hf_noise']['channel_noisiness_sd'] == 1 + + +def test_bad_by_dropout(raw_tmp): + """Test detection of channels with excessive portions of flat signal.""" + # Add large dropout portions to the signal of a random channel + n_chans, n_samples = raw_tmp._data.shape + dropout_idx = int(RNG.randint(0, n_chans, 1)) + x1, x2 = (int(n_samples / 10), int(2 * n_samples / 10)) + raw_tmp._data[dropout_idx, x1:x2] = 0 # flatten 10% of signal + + # Test detection of channels that have excessive dropout regions + nd = NoisyChannels(raw_tmp, do_detrend=False) + nd.find_bad_by_correlation() + assert nd.bad_by_dropout == [raw_tmp.ch_names[dropout_idx]] + - # Test for signal 
to noise ratio in EEG data - raw_tmp = raw.copy() - m, n = raw_tmp._data.shape - rand_chn_idx = int(rng.randint(0, m, 1)) - rand_chn_lab = raw_tmp.ch_names[rand_chn_idx] - # inserting an uncorrelated high frequency (90 Hz) signal in one channel - raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6 - nd = NoisyChannels(raw_tmp, random_state=rng) +def test_bad_by_correlation(raw_tmp): + """Test detection of channels that correlate poorly with others.""" + # Replace a random channel's signal with uncorrelated values + n_chans, n_samples = raw_tmp._data.shape + low_corr_idx = int(RNG.randint(0, n_chans, 1)) + raw_tmp._data[low_corr_idx, :] = _generate_signal(10, 30, raw_tmp.times, 5) + + # Test detection of channels that correlate poorly with others + nd = NoisyChannels(raw_tmp, do_detrend=False) + nd.find_bad_by_correlation() + assert nd.bad_by_correlation == [raw_tmp.ch_names[low_corr_idx]] + + # Add a channel with dropouts to see if correlation detection still works + dropout_idx = (low_corr_idx - 1) if low_corr_idx > 0 else 1 + x1, x2 = (int(n_samples / 10), int(2 * n_samples / 10)) + raw_tmp._data[dropout_idx, x1:x2] = 0 # flatten 10% of signal + + # Re-test detection of channels that correlate poorly with others + # (only new bad-by-correlation channel should be dropout) + nd = NoisyChannels(raw_tmp, do_detrend=False) + nd.find_bad_by_correlation() + assert raw_tmp.ch_names[low_corr_idx] in nd.bad_by_correlation + assert len(nd.bad_by_correlation) <= 2 + + +def test_bad_by_SNR(raw_tmp): + """Test detection of channels that have low signal-to-noise ratios.""" + # Replace a random channel's signal with uncorrelated values + n_chans = raw_tmp._data.shape[0] + low_snr_idx = int(RNG.randint(0, n_chans, 1)) + raw_tmp._data[low_snr_idx, :] = _generate_signal(10, 30, raw_tmp.times, 5) + + # Add some high-frequency noise to the uncorrelated channel + hf_noise = _generate_signal(70, 80, raw_tmp.times, 5) * 10 + raw_tmp._data[low_snr_idx, :] += hf_noise + + 
# Test detection of channels with a low signal-to-noise ratio + nd = NoisyChannels(raw_tmp, do_detrend=False) nd.find_bad_by_SNR() - assert rand_chn_lab in nd.bad_by_SNR + assert nd.bad_by_SNR == [raw_tmp.ch_names[low_snr_idx]] -@pytest.mark.usefixtures("raw", "montage") -def test_find_bad_by_ransac(raw, montage): +def test_find_bad_by_ransac(raw_tmp): """Test the RANSAC component of NoisyChannels.""" - # Set a fixed random seed and a montage for the tests - rng = 435656 - raw.set_montage(montage) + # Set a consistent random seed for all RANSAC runs + RANSAC_RNG = 435656 # RANSAC identifies channels that go bad together and are highly correlated. - # Inserting highly correlated signal in channels 0 through 3 at 30 Hz - raw_tmp = raw.copy() - raw_tmp._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6 - - # Pre-detrend data to save time during NoisyChannels initialization - raw_tmp._data = removeTrend(raw_tmp.get_data(), raw.info["sfreq"]) + # Inserting highly correlated signal in channels 0 through 5 at 30 Hz + raw_tmp._data[0:6, :] = _generate_signal(30, 30, raw_tmp.times) # Run different variations of RANSAC on the same data test_matrix = { @@ -164,7 +221,7 @@ def test_find_bad_by_ransac(raw, montage): corr = {} for name, args in test_matrix.items(): nd = NoisyChannels( - raw_tmp, do_detrend=False, random_state=rng, matlab_strict=args[0] + raw_tmp, do_detrend=False, random_state=RANSAC_RNG, matlab_strict=args[0] ) nd.find_bad_by_ransac(channel_wise=args[1], max_chunk_size=args[2]) # Save bad channels and RANSAC correlation matrix for later comparison @@ -188,37 +245,30 @@ def test_find_bad_by_ransac(raw, montage): # Make sure strict and non-strict matrices differ assert not np.allclose(corr['by_window'], corr['by_window_strict']) + +def test_find_bad_by_ransac_err(raw_tmp): + """Test error handling in the `find_bad_by_ransac` method.""" # Set n_samples very very high to trigger a memory error n_samples = int(1e100) - nd = NoisyChannels(raw_tmp, 
do_detrend=False, random_state=rng) + nd = NoisyChannels(raw_tmp, do_detrend=False) with pytest.raises(MemoryError): nd.find_bad_by_ransac(n_samples=n_samples) # Set n_samples to a float to trigger a type error n_samples = 35.5 - nd = NoisyChannels(raw_tmp, do_detrend=False, random_state=rng) + nd = NoisyChannels(raw_tmp, do_detrend=False) with pytest.raises(TypeError): nd.find_bad_by_ransac(n_samples=n_samples) # Test IOError when too few good channels for RANSAC sample size - raw_tmp = raw.copy() - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_all_bads(ransac=False) - # Make 80% of channels bad - num_bad_channels = int(raw._data.shape[0] * 0.8) - bad_channels = raw.info["ch_names"][0:num_bad_channels] - nd.bad_by_deviation = bad_channels + n_chans = raw_tmp._data.shape[0] + nd = NoisyChannels(raw_tmp, do_detrend=False) + nd.bad_by_deviation = raw_tmp.info["ch_names"][0:int(n_chans * 0.8)] with pytest.raises(IOError): nd.find_bad_by_ransac() - # Test IOError when not enough channels for ransac predictions - raw_tmp = raw.copy() - # Make flat all channels except 2 - num_bad_channels = raw._data.shape[0] - 2 - raw_tmp._data[0:num_bad_channels, :] = np.zeros_like( - raw_tmp._data[0:num_bad_channels, :] - ) - nd = NoisyChannels(raw_tmp, random_state=rng) - nd.find_all_bads(ransac=False) + # Test IOError when not enough channels for RANSAC predictions + raw_tmp._data[0:(n_chans - 2), :] = 0 # make all channels flat except 2 + nd = NoisyChannels(raw_tmp, do_detrend=False) with pytest.raises(IOError): nd.find_bad_by_ransac() diff --git a/tests/test_matprep_compare.py b/tests/test_matprep_compare.py index 9f191eaf..a34dfae2 100644 --- a/tests/test_matprep_compare.py +++ b/tests/test_matprep_compare.py @@ -164,15 +164,12 @@ class TestCompareNoisyChannels(object): def test_bad_by_nan(self, pyprep_noisy, matprep_noisy): """Compare bad-by-NaN results between PyPREP and MatPREP.""" - # NOTE: The current test artifacts contain no channels with NaN values - # (when 
does that ever happen?), meaning that this may not be testing - # anything useful + # Compare names of bad-by-NaN channels assert pyprep_noisy.bad_by_nan == matprep_noisy['bads']['by_nan'] def test_bad_by_flat(self, pyprep_noisy, matprep_noisy): """Compare bad-by-flat results between PyPREP and MatPREP.""" - # NOTE: The current test artifacts contain no flat channels, meaning - # that this may not be testing anything useful + # Compare names of bad-by-flat channels assert pyprep_noisy.bad_by_flat == matprep_noisy['bads']['by_flat'] def test_bad_by_deviation(self, pyprep_noisy, matprep_noisy): @@ -264,9 +261,6 @@ def test_bad_by_SNR(self, pyprep_noisy, matprep_noisy): def test_bad_by_dropout(self, pyprep_noisy, matprep_noisy): """Compare bad-by-dropout results between PyPREP and MatPREP.""" - # NOTE: The current test artifacts contain no channels with dropouts, - # meaning that this may not be testing anything useful - # Gather PyPREP and MATLAB PREP dropout info matprep_dropouts = matprep_noisy['dropOuts'] pyprep_dropouts = pyprep_noisy._extra_info['bad_by_dropout']['dropouts'] diff --git a/tests/test_reference.py b/tests/test_reference.py index ed96a2ff..5c87df1d 100644 --- a/tests/test_reference.py +++ b/tests/test_reference.py @@ -8,12 +8,13 @@ from pyprep.reference import Reference -@pytest.mark.usefixtures("raw") -def test_basic_input(raw): +@pytest.mark.usefixtures("raw", "montage") +def test_basic_input(raw, montage): """Test Reference output data type.""" ch_names = raw.info["ch_names"] raw_tmp = raw.copy() + raw_tmp.set_montage(montage) params = {"ref_chs": ch_names, "reref_chs": ch_names} reference = Reference(raw_tmp, params, ransac=False) reference.perform_reference() @@ -26,12 +27,13 @@ def test_basic_input(raw): assert type(reference.raw) == mne.io.edf.edf.RawEDF -@pytest.mark.usefixtures("raw") -def test_all_bad_input(raw): +@pytest.mark.usefixtures("raw", "montage") +def test_all_bad_input(raw, montage): """Test robust reference when all reference 
channels are bad.""" ch_names = raw.info["ch_names"] raw_tmp = raw.copy() + raw_tmp.set_montage(montage) m, n = raw_tmp.get_data().shape # Randomly set some channels as bad diff --git a/tests/test_utils.py b/tests/test_utils.py index 91c4383b..7e9ef624 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ """Test various helper functions.""" import numpy as np +from numpy.random import RandomState from pyprep.utils import ( _mat_round, _mat_quantile, _mat_iqr, _get_random_subset, _correlate_arrays, @@ -79,7 +80,7 @@ def test_mat_quantile_iqr(): def test_get_random_subset(): """Test the function for getting random channel subsets.""" # Generate test data - rng = np.random.RandomState(435656) + rng = RandomState(435656) chans = range(1, 61) # Compare random subset equivalence with MATLAB