Skip to content

Commit

Permalink
Replace mne legacy functions with modern code (#161)
Browse files Browse the repository at this point in the history
* cleaner artifact u pload, include hidden files

* replace pick_channels with .pick where appropriate

* replace pick_types with .pick
  • Loading branch information
sappelhoff authored Nov 23, 2024
1 parent 8421b03 commit 9d3f984
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,12 @@ jobs:
make -C docs/ html
- name: Upload artifacts
if: matrix.platform == 'ubuntu-latest'
if: ${{ matrix.platform == 'ubuntu-latest' && matrix.mne-version == 'mne-stable' }}
uses: actions/upload-artifact@v4
with:
name: docs-artifact
path: docs/_build/html
include-hidden-files: true

- name: Upload coverage report
uses: codecov/codecov-action@v5
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
with:
name: dist
path: dist
include-hidden-files: true

pypi-upload:
needs: package
Expand Down
2 changes: 1 addition & 1 deletion pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)
raw.load_data()
self.raw_mne = raw.copy()
self.bad_by_manual = raw.info["bads"]
self.raw_mne.pick_types(eeg=True)
self.raw_mne.pick("eeg") # excludes bads
self.sample_rate = raw.info["sfreq"]
if do_detrend:
self.raw_mne._data = removeTrend(
Expand Down
4 changes: 2 additions & 2 deletions pyprep/prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ def __init__(
if self.ch_types_all[i] == "eeg"
]
self.ch_names_non_eeg = list(set(self.ch_names_all) - set(self.ch_names_eeg))
self.raw_eeg.pick_channels(self.ch_names_eeg)
self.raw_eeg.pick(self.ch_names_eeg)
if self.ch_names_non_eeg == []:
self.raw_non_eeg = None
else:
self.raw_non_eeg = raw.copy()
self.raw_non_eeg.pick_channels(self.ch_names_non_eeg)
self.raw_non_eeg.pick(self.ch_names_non_eeg)

self.raw_eeg.set_montage(montage)
# raw_non_eeg may not be compatible with the montage
Expand Down
5 changes: 3 additions & 2 deletions pyprep/ransac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Authors: The PyPREP developers
# SPDX-License-Identifier: MIT

import mne
import numpy as np
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import ProgressBar, check_random_state, logger
Expand Down Expand Up @@ -138,7 +137,9 @@ def find_bad_by_ransac(
# Get all channel positions and the position subset of "clean channels"
# Exclude should be the bad channels from other methods
# That is, identify all bad channels by other means
good_idx = mne.pick_channels(list(complete_chn_labs), include=[], exclude=exclude)
good_idx = np.array(
[idx for idx, ch in enumerate(complete_chn_labs) if ch not in exclude]
)
n_chans_good = good_idx.shape[0]
chn_pos_good = chn_pos[good_idx, :]

Expand Down
4 changes: 2 additions & 2 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(
raw.load_data()
self.raw = raw.copy()
self.ch_names = self.raw.ch_names
self.raw.pick_types(eeg=True, eog=False, meg=False, exclude=[])
self.raw.pick("eeg", exclude=[]) # include previously marked bads
self.bads_manual = raw.info["bads"]
self.ch_names_eeg = self.raw.ch_names
self.EEG = self.raw.get_data()
self.reference_channels = params["ref_chs"]
Expand All @@ -97,7 +98,6 @@ def __init__(
self.random_state = check_random_state(random_state)
self._extra_info = {}
self.matlab_strict = matlab_strict
self.bads_manual = raw.info["bads"]

def perform_reference(self, max_iterations=4):
"""Estimate the true signal mean and interpolate bad channels.
Expand Down
13 changes: 10 additions & 3 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import math
from cmath import sqrt

import mne
import numpy as np
import scipy.interpolate
from mne.surface import _normalize_vectors
Expand Down Expand Up @@ -361,8 +360,16 @@ def _eeglab_interpolate_bads(raw):
"""
# Get the indices of good and bad EEG channels
eeg_chans = mne.pick_types(raw.info, eeg=True, exclude=[])
good_idx = mne.pick_types(raw.info, eeg=True, exclude="bads")
eeg_chans = np.array(
[idx for idx, typ in enumerate(raw.get_channel_types()) if typ == "eeg"]
)
good_idx = np.array(
[
idx
for idx, (ch, typ) in enumerate(zip(raw.ch_names, raw.get_channel_types()))
if (typ == "eeg") and (ch not in raw.info["bads"])
]
)
bad_idx = sorted(_set_diff(eeg_chans, good_idx))

# Get the spatial coordinates of the good and bad electrodes
Expand Down
9 changes: 6 additions & 3 deletions tests/test_prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# SPDX-License-Identifier: MIT

import matplotlib.pyplot as plt
import mne
import numpy as np
import pytest
import scipy.io as sio
Expand All @@ -17,7 +16,9 @@
@pytest.mark.usefixtures("raw", "montage")
def test_prep_pipeline(raw, montage):
"""Test prep pipeline."""
eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
eeg_index = np.array(
[idx for idx, typ in enumerate(raw.get_channel_types()) if typ == "eeg"]
)
raw_copy = raw.copy()
ch_names = raw_copy.info["ch_names"]
ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
Expand Down Expand Up @@ -258,7 +259,9 @@ def test_prep_pipeline_non_eeg(raw, montage):
@pytest.mark.usefixtures("raw", "montage")
def test_prep_pipeline_filter_kwargs(raw, montage):
"""Test prep pipeline with filter kwargs."""
eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
eeg_index = np.array(
[idx for idx, typ in enumerate(raw.get_channel_types()) if typ == "eeg"]
)
raw_copy = raw.copy()
ch_names = raw_copy.info["ch_names"]
ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
Expand Down

0 comments on commit 9d3f984

Please sign in to comment.