Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ransac Comparison Example #53

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- uses: actions/cache@v2
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}-${{ hashFiles('requirements-dev.txt') }}
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}-${{ hashFiles('requirements-dev.txt') }}-version-1

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Changelog
- Channel types are now available from a new ``ch_types_all`` attribute, and non-EEG channel names are now available from a new ``ch_names_non_eeg`` attribute from :class:`PrepPipeline <pyprep.PrepPipeline>`, by `Yorguin Mantilla`_ (:gh:`34`)
- Renaming of ``ch_names`` attribute of :class:`PrepPipeline <pyprep.PrepPipeline>` to ``ch_names_all``, by `Yorguin Mantilla`_ (:gh:`34`)
- It's now possible to pass ``'eeg'`` to ``ref_chs`` and ``reref_chs`` keywords to the ``prep_params`` parameter of :class:`PrepPipeline <pyprep.PrepPipeline>` to select only eeg channels for referencing, by `Yorguin Mantilla`_ (:gh:`34`)
- :class:`PrepPipeline <pyprep.PrepPipeline>` will retain the non eeg channels through the ``raw`` attribute. The eeg-only and non-eeg parts will be in raw_eeg and raw_non_eeg respectively. See the ``raw`` attribute, by `Christian Oreilly`_ (:gh:`34`)
- :class:`PrepPipeline <pyprep.PrepPipeline>` will retain the non eeg channels through the ``raw`` attribute. The eeg-only and non-eeg parts will be in raw_eeg and raw_non_eeg respectively. See the ``raw`` attribute, by `Christian O'Reilly`_ (:gh:`34`)
- When a ransac call needs more memory than available, pyprep will now automatically switch to a slower but less memory-consuming version of ransac, by `Yorguin Mantilla`_ (:gh:`32`)
- It's now possible to pass an empty list for the ``line_freqs`` param in :class:`PrepPipeline <pyprep.PrepPipeline>` to skip the line noise removal, by `Yorguin Mantilla`_ (:gh:`29`)
- The three main classes :class:`~pyprep.PrepPipeline`, :class:`~pyprep.NoisyChannels`, and :class:`pyprep.Reference` now have a ``random_state`` parameter to set a seed that gets passed on to all their internal methods and class calls, by `Stefan Appelhoff`_ (:gh:`31`)
Expand Down
212 changes: 212 additions & 0 deletions examples/ransac_comparison_autoreject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""
===============================================
RANSAC comparison between pyprep and autoreject
===============================================

Next to the RANSAC implementation in ``pyprep``,
there is another implementation that makes use of MNE-Python.
That alternative RANSAC implementation can be found in the
`"autoreject" package <https://github.com/autoreject/autoreject/>`_.

In this example, we make a basic comparison between the two implementations.

#. by running them on the same simulated data
#. by running them on the same "real" data


.. currentmodule:: pyprep
"""

# Authors: Yorguin Mantilla <[email protected]>
#
# License: MIT

# %%
# First we import what we need for this example.
import numpy as np
import mne
from scipy import signal as signal
from time import perf_counter
from autoreject import Ransac
import pyprep.ransac as ransac_pyprep


# %%
# Now let's make some arbitrary MNE raw object for demonstration purposes.
# We will think of good channels as sine waves and bad channels correlated with
# each other as sawtooths. The RANSAC will be biased towards sines in its
# prediction (they are the majority) so it will identify the sawtooths as bad.

# Set a random seed to make this example reproducible
rng = np.random.RandomState(435656)

# start defining some key aspects for our simulated data
sfreq = 1000.0
montage = mne.channels.make_standard_montage("standard_1020")
ch_names = montage.ch_names
n_chans = len(ch_names)
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_chans)
time = np.arange(0, 30, 1.0 / sfreq) # 30 seconds of recording

# randomly pick some "bad" channels (sawtooths)
n_bad_chans = 3
bad_channels = rng.choice(np.arange(n_chans), n_bad_chans, replace=False)
bad_channels = [int(i) for i in bad_channels]
bad_ch_names = [ch_names[i] for i in bad_channels]

# The frequency components to use in the signal for good and bad channels
freq_good = 20
freq_bad = 20

# Generate the data: sinewaves for "good", sawtooths for "bad" channels
X = [
signal.sawtooth(2 * np.pi * freq_bad * time)
if i in bad_channels
else np.sin(2 * np.pi * freq_good * time)
for i in range(n_chans)
]

# Scale the signal amplitude and add noise.
X = 2e-5 * np.array(X) + 1e-5 * np.random.random((n_chans, time.shape[0]))

# Finally, put it all together as an mne "Raw" object.
raw = mne.io.RawArray(X, info)
raw.set_montage(montage, verbose=False)

# Print, which channels are simulated as "bad"
print(bad_ch_names)

# %%
# Configure RANSAC parameters
n_samples = 50
fraction_good = 0.25
corr_thresh = 0.75
fraction_bad = 0.4
corr_window_secs = 5.0

# %%
# autoreject's RANSAC
ransac_ar = Ransac(
picks=None,
n_resample=n_samples,
min_channels=fraction_good,
min_corr=corr_thresh,
unbroken_time=fraction_bad,
n_jobs=1,
random_state=rng,
)
epochs = mne.make_fixed_length_epochs(
raw,
duration=corr_window_secs,
preload=True,
reject_by_annotation=False,
verbose=None,
)

start_time = perf_counter()
ransac_ar = ransac_ar.fit(epochs)
print("--- %s seconds ---" % (perf_counter() - start_time))

corr_ar = ransac_ar.corr_
bad_by_ransac_ar = ransac_ar.bad_chs_

# Check channels that go bad together by RANSAC
print("autoreject bad chs:", bad_by_ransac_ar)
assert set(bad_ch_names) == set(bad_by_ransac_ar)

# %%
# pyprep's RANSAC

start_time = perf_counter()
bad_by_ransac_pyprep, corr_pyprep = ransac_pyprep.find_bad_by_ransac(
data=raw._data.copy(),
sample_rate=raw.info["sfreq"],
complete_chn_labs=np.asarray(raw.info["ch_names"]),
chn_pos=raw._get_channel_positions(),
exclude=[],
n_samples=n_samples,
sample_prop=fraction_good,
corr_thresh=corr_thresh,
frac_bad=fraction_bad,
corr_window_secs=corr_window_secs,
channel_wise=False,
random_state=rng,
)
print("--- %s seconds ---" % (perf_counter() - start_time))

# Check channels that go bad together by RANSAC
print("pyprep bad chs:", bad_by_ransac_pyprep)
assert set(bad_ch_names) == set(bad_by_ransac_pyprep)

# %%
# Now we test the algorithms on real EEG data.
# Let's download some data for testing.
data_paths = mne.datasets.eegbci.load_data(subject=4, runs=1, update_path=True)
fname_test_file = data_paths[0]

# %%
# Load data and prepare it

raw = mne.io.read_raw_edf(fname_test_file, preload=True)

# The eegbci data has non-standard channel names. We need to rename them:
mne.datasets.eegbci.standardize(raw)

# Add a montage to the data
montage_kind = "standard_1005"
montage = mne.channels.make_standard_montage(montage_kind)
raw.set_montage(montage)


# %%
# autoreject's RANSAC
ransac_ar = Ransac(
picks=None,
n_resample=n_samples,
min_channels=fraction_good,
min_corr=corr_thresh,
unbroken_time=fraction_bad,
n_jobs=1,
random_state=rng,
)
epochs = mne.make_fixed_length_epochs(
raw,
duration=corr_window_secs,
preload=True,
reject_by_annotation=False,
verbose=None,
)

start_time = perf_counter()
ransac_ar = ransac_ar.fit(epochs)
print("--- %s seconds ---" % (perf_counter() - start_time))

corr_ar = ransac_ar.corr_
bad_by_ransac_ar = ransac_ar.bad_chs_

# Check channels that go bad together by RANSAC
print("autoreject bad chs:", bad_by_ransac_ar)


# %%
# pyprep's RANSAC

start_time = perf_counter()
bad_by_ransac_pyprep, corr_pyprep = ransac_pyprep.find_bad_by_ransac(
data=raw._data.copy(),
sample_rate=raw.info["sfreq"],
complete_chn_labs=np.asarray(raw.info["ch_names"]),
chn_pos=raw._get_channel_positions(),
exclude=[],
n_samples=n_samples,
sample_prop=fraction_good,
corr_thresh=corr_thresh,
frac_bad=fraction_bad,
corr_window_secs=corr_window_secs,
channel_wise=False,
random_state=rng,
)
print("--- %s seconds ---" % (perf_counter() - start_time))

# Check channels that go bad together by RANSAC
print("pyprep bad chs:", bad_by_ransac_pyprep)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add one more code block below to easily show:

  • chs bad that are equal between pyprep and AR
  • chs bad only AR
  • chs bad only pyprep

2 changes: 1 addition & 1 deletion examples/run_ransac.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
# `raw` object. For more information, we can access attributes of the ``nd``
# instance:

# Check channels that go bad together by correlation (RANSAC)
# Check channels that go bad together by RANSAC
print(nd.bad_by_ransac)
assert set(bad_ch_names) == set(nd.bad_by_ransac)

Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
git+git://github.com/autoreject/autoreject.git@master#egg=autoreject
black
check-manifest
flake8
Expand Down