Skip to content

Commit

Permalink
replace pick_types with .pick
Browse files Browse the repository at this point in the history
  • Loading branch information
sappelhoff committed Nov 23, 2024
1 parent 7debc3f commit 8f54670
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 12 deletions.
3 changes: 1 addition & 2 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)
raw.load_data()
self.raw_mne = raw.copy()
self.bad_by_manual = raw.info["bads"]
# Do not work on channels that were manually marked as bad
self.raw_mne.pick_types(eeg=True)
self.raw_mne.pick("eeg") # excludes bads
self.sample_rate = raw.info["sfreq"]
if do_detrend:
self.raw_mne._data = removeTrend(
Expand Down
5 changes: 3 additions & 2 deletions pyprep/ransac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Authors: The PyPREP developers
# SPDX-License-Identifier: MIT

import mne
import numpy as np
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import ProgressBar, check_random_state, logger
Expand Down Expand Up @@ -138,7 +137,9 @@ def find_bad_by_ransac(
# Get all channel positions and the position subset of "clean channels"
# Exclude should be the bad channels from other methods
# That is, identify all bad channels by other means
good_idx = np.array([idx for idx, ch in enumerate(complete_chn_labs) if ch not in exclude])
good_idx = np.array(
[idx for idx, ch in enumerate(complete_chn_labs) if ch not in exclude]
)
n_chans_good = good_idx.shape[0]
chn_pos_good = chn_pos[good_idx, :]

Expand Down
4 changes: 2 additions & 2 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(
raw.load_data()
self.raw = raw.copy()
self.ch_names = self.raw.ch_names
self.raw.pick_types(eeg=True, eog=False, meg=False, exclude=[])
self.raw.pick("eeg", exclude=[]) # include previously marked bads
self.bads_manual = raw.info["bads"]
self.ch_names_eeg = self.raw.ch_names
self.EEG = self.raw.get_data()
self.reference_channels = params["ref_chs"]
Expand All @@ -97,7 +98,6 @@ def __init__(
self.random_state = check_random_state(random_state)
self._extra_info = {}
self.matlab_strict = matlab_strict
self.bads_manual = raw.info["bads"]

def perform_reference(self, max_iterations=4):
"""Estimate the true signal mean and interpolate bad channels.
Expand Down
13 changes: 10 additions & 3 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import math
from cmath import sqrt

import mne
import numpy as np
import scipy.interpolate
from mne.surface import _normalize_vectors
Expand Down Expand Up @@ -361,8 +360,16 @@ def _eeglab_interpolate_bads(raw):
"""
# Get the indices of good and bad EEG channels
eeg_chans = mne.pick_types(raw.info, eeg=True, exclude=[])
good_idx = mne.pick_types(raw.info, eeg=True, exclude="bads")
eeg_chans = np.array(
[idx for idx, typ in enumerate(raw.get_channel_types()) if typ == "eeg"]
)
good_idx = np.array(
[
idx
for idx, (ch, typ) in enumerate(zip(raw.ch_names, raw.get_channel_types()))
if (typ == "eeg") and (ch not in raw.info["bads"])
]
)
bad_idx = sorted(_set_diff(eeg_chans, good_idx))

# Get the spatial coordinates of the good and bad electrodes
Expand Down
9 changes: 6 additions & 3 deletions tests/test_prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# SPDX-License-Identifier: MIT

import matplotlib.pyplot as plt
import mne
import numpy as np
import pytest
import scipy.io as sio
Expand All @@ -17,7 +16,9 @@
@pytest.mark.usefixtures("raw", "montage")
def test_prep_pipeline(raw, montage):
"""Test prep pipeline."""
eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
eeg_index = np.array(
[idx for idx, typ in enumerate(raw.get_channel_types()) if typ == "eeg"]
)
raw_copy = raw.copy()
ch_names = raw_copy.info["ch_names"]
ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
Expand Down Expand Up @@ -258,7 +259,9 @@ def test_prep_pipeline_non_eeg(raw, montage):
@pytest.mark.usefixtures("raw", "montage")
def test_prep_pipeline_filter_kwargs(raw, montage):
"""Test prep pipeline with filter kwargs."""
eeg_index = mne.pick_types(raw.info, eeg=True, eog=False, meg=False)
eeg_index = np.array(
[idx for idx, typ in enumerate(raw.get_channel_types()) if typ == "eeg"]
)
raw_copy = raw.copy()
ch_names = raw_copy.info["ch_names"]
ch_names_eeg = list(np.asarray(ch_names)[eeg_index])
Expand Down

0 comments on commit 8f54670

Please sign in to comment.