From 13c0a1a61248f08cf95a1b89fcd1d1ebbba68b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20P=C3=A9rez-Velasco?= Date: Wed, 10 Jul 2024 13:04:56 +0200 Subject: [PATCH 1/2] Solved minor issue in string of msg for TFExtrasNotInstalled from tensorflow_integration.py Added version 1.3.1 to setup.py string --- medusa/tensorflow_integration.py | 6 +++--- setup.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/medusa/tensorflow_integration.py b/medusa/tensorflow_integration.py index bb6d6e7..3b85901 100644 --- a/medusa/tensorflow_integration.py +++ b/medusa/tensorflow_integration.py @@ -11,9 +11,9 @@ class TFExtrasNotInstalled(Exception): def __init__(self, msg=None): if msg is None: - msg = 'This functionality requires tensorflow extras. Reinstall ' - 'medusa-kernel using the following command "pip install ' - 'medusa-kernel[TF]' + msg = ('This functionality requires tensorflow extras. Reinstall ' + 'medusa-kernel using the following command "pip install ' + 'medusa-kernel[TF]') super().__init__(msg) diff --git a/setup.py b/setup.py index 9bc4781..ba1c609 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='medusa-kernel', packages=find_packages(), - version='1.3.0', + version='1.3.1', keywords=['Signal', 'Biosignal', 'EEG', 'BCI'], url='https://medusabci.com/', author='Eduardo Santamaría-Vázquez, ' From b1b63aa6db645b9e9f0dcf9d74f51cb2ab863b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20P=C3=A9rez-Velasco?= Date: Wed, 10 Jul 2024 14:45:27 +0200 Subject: [PATCH 2/2] Adapted EEGSym model inside deep_learning_models.py and its usage on mi_paradigms.py to be consistent with new structure of all classes of mi_paradigms.py. --- medusa/bci/mi_paradigms.py | 194 ++++++++++++++++++++++++++------- medusa/deep_learning_models.py | 94 +++++++++++----- 2 files changed, 220 insertions(+), 68 deletions(-) diff --git a/medusa/bci/mi_paradigms.py b/medusa/bci/mi_paradigms.py index 0d34f30..caee389 100644 --- a/medusa/bci/mi_paradigms.py +++ b/medusa/bci/mi_paradigms.py @@ -454,7 +454,7 @@ class StandardFeatureExtraction(components.ProcessingMethod): it gets the raw epoch for each MI event. """ - def __init__(self, safe_copy=True): + def __init__(self, safe_copy=True, **kwargs): """ Class constructor. All parameters except "safe_copy" must be specified in each method. @@ -607,15 +607,15 @@ def transform_dataset(self, dataset, show_progress_bar=True, Length in samples of the sliding windows. Please note that the w_epoch_t parameter would not be used if the sliding window approach is used. If None, no sliding window would be applied. - baseline_mode : {'run', 'trial', None, 'sliding'} + baseline_mode : {'run', 'trial', 'sliding', None} The baseline_mode has an additional feature when sliding window is used: - "run": common baseline extracted from the start of the run. - "trial": common baseline for the trial, i.e., all the windows that belongs to an onset; extracted relative to the onset. - - None: no baseline. - "sliding": baseline is applied to each sliding window, and it is relative to the start of each sliding window. + - None: no baseline. Returns ------- @@ -667,7 +667,7 @@ def transform_dataset(self, dataset, show_progress_bar=True, desc='Extracting features') # Compute features - features = None + features_list = list() for run_counter, rec in enumerate(dataset.recordings): # Extract recording experiment and biosignal rec_exp = getattr(rec, dataset.experiment_att_key) @@ -724,9 +724,9 @@ def transform_dataset(self, dataset, show_progress_bar=True, target_fs=target_fs, concatenate_channels=concatenate_channels ) - features = np.concatenate((features, rec_feat), axis=0) \ - if features is not None else rec_feat - + features_list.append(rec_feat) + features = np.concatenate((features_list), axis=0) if ( + features_list) else None # Track experiment info for key, value in track_attributes.items(): if value['parent'] is None: @@ -1138,12 +1138,10 @@ def fit_dataset(self, dataset, get_training_accuracy=True, k_fold_iter = cu.k_fold_split(x, x_info['mi_labels'], k_fold) k_fold_acc = 0 for iter in k_fold_iter: - self.get_inst('clf_method').fit( - iter["x_train"], iter["y_train"]) - y_test_pred = self.get_inst('clf_method').predict( - iter["x_test"]) - y_test_prob = self.get_inst('clf_method').predict_proba( - iter["x_test"]) + clf = copy.deepcopy(self.get_inst('clf_method')) + clf.fit(iter["x_train"], iter["y_train"]) + y_test_pred = clf.predict(iter["x_test"]) + y_test_prob = clf.predict_proba(iter["x_test"]) fold_acc = np.sum(y_test_pred == iter["y_test"]) / \ len(iter["y_test"]) k_fold_acc += fold_acc @@ -1325,74 +1323,106 @@ class MIModelEEGSym(MIModel): def __init__(self): super().__init__() - def configure(self, cnn_n_cha=8, ch_lateral=3, fine_tuning=False, - shuffle_before_fit=True, validation_split=0.4, + def configure(self, p_filt_cutoff=(0.1, 45), w_epoch_t=(0, 2000), + w_baseline_t=(0, 2000), target_fs=128, cnn_n_cha=8, + ch_lateral=3, fine_tuning=False, validation_split=0.4, init_weights_path=None, gpu_acceleration=False, - augmentation=False): + augmentation=False, **kwargs): self.settings = { + # StandardPreprocessing + 'p_filt_cutoff': p_filt_cutoff, + # StandardFeatureExtraction + 'w_epoch_t': w_epoch_t, + 'baseline_mode': 'sliding', + 'w_baseline_t': w_baseline_t, + 'norm': 'z', + 'target_fs': target_fs, + 'concatenate_channels': False, + 'sliding_w_lims_t': None, + 'sliding_t_step': None, + 'sliding_win_len': None, + # EEGSym features 'cnn_n_cha': cnn_n_cha, 'ch_lateral': ch_lateral, 'fine_tuning': fine_tuning, 'augmentation': augmentation, - 'shuffle_before_fit': shuffle_before_fit, 'validation_split': validation_split, 'init_weights_path': init_weights_path, 'gpu_acceleration': gpu_acceleration } + self.settings = dict(self.settings, **kwargs) # Update state self.is_configured = True self.is_built = False self.is_fit = False def build(self): + """ Initializes the different methods that comprise the MIModelEEGSym + pipeline. + """ # Check errors if not self.is_configured: raise ValueError('Function configure must be called first!') # Only import deep learning models if necessary from medusa.deep_learning_models import EEGSym - # Preprocessing (bandpass IIR filter [0.5, 45] Hz + CAR) + # Preprocessing (default: bandpass IIR filter [0.5, 45] Hz + CAR) self.add_method('prep_method', - StandardPreprocessing(cutoff=49, btype='lowpass')) + StandardPreprocessing(cutoff=self.settings['p_filt_cutoff'])) # Feature extraction (epochs [0, 2000] ms + resampling to 128 Hz) - self.add_method('ext_method', StandardFeatureExtraction(safe_copy=True)) + self.add_method('ext_method', StandardFeatureExtraction(**self.settings)) # Feature classification clf = EEGSym( - input_time=2000, - fs=128, + input_time=int(self.settings['w_epoch_t'][1] - + self.settings['w_epoch_t'][0]), + fs=self.settings['target_fs'], n_cha=self.settings['cnn_n_cha'], ch_lateral=self.settings['ch_lateral'], filters_per_branch=24, scales_time=(125, 250, 500), dropout_rate=0.4, activation='elu', n_classes=2, - learning_rate=0.0001, + learning_rate=0.001, gpu_acceleration=self.settings['gpu_acceleration']) self.is_fit = False if self.settings['init_weights_path'] is not None: clf.load_weights(self.settings['init_weights_path']) self.channel_set = meeg.EEGChannelSet() - standard_lcha = ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4'] + standard_lcha = ['F7', 'C3', 'PO3', 'CZ', 'PZ', 'F8', 'C4', 'PO4'] self.channel_set.set_standard_montage(standard_lcha) + self.get_inst('prep_method').fit(fs=250, n_cha=8) self.is_fit = True + else: + self.is_fit = False self.add_method('clf_method', clf) # Update state self.is_built = True - # self.is_fit = False - def fit_dataset(self, dataset, continuous=False, **kwargs): + def fit_dataset(self, dataset, **kwargs): + """ Function to fit a dataset using MIModelCSP. + + Parameters + ------------- + dataset : MIDataset + MI dataset used for training. + **kwargs : dict + These parameters will be overwritten over self.settings. + + Returns + ------------- + assessment : dict + Dictionary containing the details of the training accuracy + estimation. + """ # Check errors if not self.is_built: raise ValueError('Function build must be called first!') + # Merge settings + settings = dict(self.settings, **kwargs) # Preprocessing dataset = self.get_inst('prep_method').fit_transform_dataset(dataset) # Extract features - # TODO: perform sliding window? if so, we need the lims - # TODO: hardcoded? - x, x_info = self.get_inst('ext_method').transform_dataset( - dataset, w_epoch_t=(0, 2000), baseline_mode="trial", - w_baseline_t=(0, 2000), norm="z", target_fs=128, - sliding_w_lims_t=None, sliding_t_step=None, sliding_win_len=None - ) + x, x_info = self.get_inst('ext_method').transform_dataset(dataset, + **settings) # Put channels in symmetric order x, _ = self.get_inst('clf_method').symmetric_channels( x, dataset.channel_set.l_cha) @@ -1401,12 +1431,12 @@ def fit_dataset(self, dataset, continuous=False, **kwargs): self.get_inst('clf_method').fit( x, x_info['mi_labels'], fine_tuning=self.settings['fine_tuning'], - shuffle_before_fit=self.settings['shuffle_before_fit'], validation_split=self.settings['validation_split'], augmentation=self.settings['augmentation'], **kwargs) + y_prob = self.get_inst('clf_method').predict_proba(x) - y_pred = y_prob.argmax(axis=-1) + y_pred = self.get_inst('clf_method').predict(x) # Accuracy accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred) @@ -1427,27 +1457,49 @@ def fit_dataset(self, dataset, continuous=False, **kwargs): return assessment def predict(self, times, signal, fs, channel_set, x_info, **kwargs): + """ Function to predict an individual signal in MIModelCSP. + + Parameters + -------------- + times : ndarray (n_samples,) + Timestamp array + signal: ndarray (n_samples x n_channels) + Signal data. + fs : int + Sampling frequency. + channel_set : EEGChannelSet or similar + Channel montage. + x_info : dict + Dictionary containing the trial "onsets" and "mi_labels". If the + latter is not specified, accuracy is not calculated. + **kwargs : dict + These parameters will be overwritten over self.settings. + + Returns + ------------- + decoding : dict + Dictionary containing the decoding. + """ # Check errors if not self.is_fit: raise ValueError('Function fit_dataset must be called first!') + # Merge settings + settings = dict(self.settings, **kwargs) # Check channel set - if self.channel_set.channels != channel_set.channels: + if self.channel_set.l_cha != channel_set.l_cha: warnings.warn('The channel set is not the same that was used to ' 'fit the model. Be careful!') # Preprocessing - signal = self.get_inst('prep_method').fit_transform_signal(signal, fs) + signal = self.get_inst('prep_method').transform_signal(signal) # Extract features - # TODO: hardcoded? x = self.get_inst('ext_method').transform_signal( times=times, signal=signal, fs=fs, onsets=x_info['onsets'], - w_epoch_t=(0, 2000), baseline_mode="trial", w_baseline_t=(0, 2000), - norm="z", target_fs=128 - ) + **settings) # Put channels in symmetric order x, _ = self.get_inst('clf_method').symmetric_channels( - x, self.channel_set.l_cha) + x, channel_set.l_cha) # Classification y_prob = self.get_inst('clf_method').predict_proba(x) @@ -1469,3 +1521,61 @@ def predict(self, times, signal, fs, channel_set, x_info, **kwargs): 'report': clf_report } return decoding + + def predict_dataset(self, dataset, **kwargs): + """ Function to predict a dataset using MIModelCSP. + + Parameters + ------------- + dataset : MIDataset + Test dataset. + **kwargs : dict + These parameters will be overwritten over self.settings. + + Returns + ------------- + decoding : dict + Dictionary containing the decoding. + """ + # Check errors + if not self.is_fit: + raise ValueError('Function fit_dataset must be called first!') + # Check channel set + if self.channel_set.l_cha != dataset.channel_set.l_cha: + warnings.warn('The channel set is not the same that was used to ' + 'fit the model. Be careful!') + # Merge settings + settings = dict(self.settings, **kwargs) + + # Preprocessing + dataset = self.get_inst('prep_method').transform_dataset(dataset, + **settings) + + # Extract features + x, x_info = self.get_inst('ext_method').transform_dataset( + dataset, **settings) + + # Put channels in symmetric order + x, _ = self.get_inst('clf_method').symmetric_channels(x, + dataset.channel_set.l_cha) + + # Classification + y_prob = self.get_inst('clf_method').predict_proba(x) + y_pred = y_prob.argmax(axis=-1) + + # Decoding + accuracy = None + clf_report = None + if x_info['mi_labels'] is not None: + accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred) + clf_report = classification_report(x_info['mi_labels'], y_pred, + output_dict=True) + decoding = { + 'x': x, + 'x_info': x_info, + 'y_prob': y_prob, + 'y_pred': y_pred, + 'accuracy': accuracy, + 'report': clf_report + } + return decoding diff --git a/medusa/deep_learning_models.py b/medusa/deep_learning_models.py index 03dba46..4c14869 100644 --- a/medusa/deep_learning_models.py +++ b/medusa/deep_learning_models.py @@ -1,6 +1,7 @@ # Built-in imports import warnings import os +import re # External imports import sklearn.utils as sk_utils @@ -10,6 +11,7 @@ from medusa import components from medusa import classification_utils from medusa import tensorflow_integration +from medusa.meeg import get_standard_montage # Extras if os.environ.get("MEDUSA_EXTRAS_GPU_TF") == "1": @@ -671,13 +673,11 @@ class EEGSym(components.ProcessingMethod): Variability in Motor Imagery Based BCIs with Deep Learning. IEEE Transactions on Neural Systems and Rehabilitation Engineering. """ - #TODO: Implement automatic ordering of channels - #TODO: Implement trial iterator and data augmentation def __init__(self, input_time=3000, fs=128, n_cha=8, filters_per_branch=24, scales_time=(500, 250, 125), dropout_rate=0.4, activation='elu', n_classes=2, learning_rate=0.001, ch_lateral=3, spatial_resnet_repetitions=1, residual=True, symmetric=True, - gpu_acceleration=None): + gpu_acceleration=False): # Super call super().__init__(fit=[], predict_proba=['y_pred']) @@ -1347,8 +1347,7 @@ def trial_iterator(self, X, y, batch_size=32, shuffle=True, ) return trial_iterator - def fit(self, X, y, fine_tuning=False, shuffle_before_fit=False, - augmentation=True, **kwargs): + def fit(self, X, y, fine_tuning=False, augmentation=True, **kwargs): """Fit the model. All additional keras parameters of class tensorflow.keras.Model will pass through. See keras documentation to know what can you do: https://keras.io/api/models/model_training_apis/. @@ -1373,10 +1372,6 @@ def fit(self, X, y, fine_tuning=False, shuffle_before_fit=False, fine_tuning: bool Set to True to use the default training parameters for fine tuning. False by default. - shuffle_before_fit: bool - If True, the data will be shuffled before training just once. Note - that if you use the keras native argument 'shuffle', the data is - shuffled each epoch. kwargs: Key-value arguments will be passed to the fit function of the model. This way, you can set your own training parameters using keras API. @@ -1387,10 +1382,6 @@ def fit(self, X, y, fine_tuning=False, shuffle_before_fit=False, warnings.warn('GPU acceleration is not available. The training ' 'time is drastically reduced with GPU.') - # Shuffle the data before fitting - if shuffle_before_fit: - X, y = sk_utils.shuffle(X, y) - # Creation of validation split val_split = kwargs['validation_split'] if \ 'validation_split' in kwargs else 0.4 @@ -1439,26 +1430,65 @@ def fit(self, X, y, fine_tuning=False, shuffle_before_fit=False, validation_data=(X[val_idx], y[val_idx]), **kwargs) - def symmetric_channels(self, X, channels): + def sort_channels_by_y(self, channels, front_to_back=True): + """Sort a list of EEG channels based on their y-coordinate position in the + 10-05 system. + + Parameters + ------------ + channels: list + List of EEG channel names. + front_to_back: bool, optional + Determines the sorting order, from front to back or from back to front. + Defaults to True. + + Returns + ------------ + sorted_channels: list + Sorted list of EEG channel names. + + """ + eeg_dictionary = get_standard_montage(standard='10-05', dim='3D', + coord_system='cartesian') + region_channels = {} + for channel in channels: + match = re.search(r'^[a-zA-Z]+(?=[\d|z|Z])', channel) + if match: + prefix = match.group() + if prefix not in region_channels: + region_channels[prefix] = [] + region_channels[prefix].append(channel) + + sorted_channels = sorted(channels, key=lambda x: (np.mean( + [eeg_dictionary[ch.upper()]['y'] for ch in + region_channels[re.search(r'^[a-zA-Z]+(?=[\d|z|Z])', x).group()]]), + abs(eeg_dictionary[ + x.upper()][ + 'x'])), + reverse=front_to_back) + return sorted_channels + def symmetric_channels(self, X, channels, front_to_back=True): """This function takes a set of channels and puts them in a symmetric input needed to apply EEGSym. """ - left = list() - right = list() - middle = list() + left, right, middle = [], [], [] for channel in channels: - if channel[-1].isnumeric(): - if int(channel[-1]) % 2 == 0: - right.append(channel) - else: - left.append(channel) + number = re.search(r"\d+", channel) + if number is not None: + (left if int(number[0]) % 2 else right).append(channel) else: middle.append(channel) + + left = self.sort_channels_by_y(left, front_to_back=front_to_back) + right = self.sort_channels_by_y(right, front_to_back=front_to_back) + middle = self.sort_channels_by_y(middle, front_to_back=front_to_back) + ordered_channels = left + middle + right - index_channels = [channels.index(channel) for channel in + index_channels = [list(channels).index(channel) for channel in ordered_channels] - return np.array(X)[:, :, index_channels], list(np.array(channels)[ - index_channels]) + + return np.take(X, index_channels, axis=-1), \ + np.take(channels, index_channels, axis=0) def predict_proba(self, X): """Model prediction scores for the given features. @@ -1475,6 +1505,18 @@ def predict_proba(self, X): # Predict with tf.device(tensorflow_integration.get_tf_device_name()): return self.model.predict(X) + def predict(self, X): + """Model prediction for the given features. + + Parameters + ---------- + X: np.ndarray + Feature matrix. If shape is [n_observ x n_samples x n_channels], + this matrix will be adapted to the input dimensions of EEG-Sym + [n_observ x n_samples x n_channels x 1] + """ + y_prob = self.predict_proba(X) + return y_prob.argmax(axis=-1) def to_pickleable_obj(self): # Parameters @@ -1500,7 +1542,7 @@ def to_pickleable_obj(self): @classmethod def from_pickleable_obj(cls, pickleable_obj): - pickleable_obj['kwargs']['gpu_acceleration'] = None + # pickleable_obj['kwargs']['gpu_acceleration'] = None model = cls(**pickleable_obj['kwargs']) model.model.set_weights(pickleable_obj['weights']) return model