Skip to content

Commit

Permalink
Merge pull request #29 from medusabci/solve_errors_tensorflow_integra…
Browse files Browse the repository at this point in the history
…tion

Solved minor issue in string of msg for TFExtrasNotInstalled from tensorflow_integration.py
  • Loading branch information
esantamariavazquez authored Jul 11, 2024
2 parents 868b625 + b1b63aa commit 00112a7
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 72 deletions.
194 changes: 152 additions & 42 deletions medusa/bci/mi_paradigms.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class StandardFeatureExtraction(components.ProcessingMethod):
it gets the raw epoch for each MI event.
"""

def __init__(self, safe_copy=True):
def __init__(self, safe_copy=True, **kwargs):
"""
Class constructor. All parameters except "safe_copy" must be specified
in each method.
Expand Down Expand Up @@ -607,15 +607,15 @@ def transform_dataset(self, dataset, show_progress_bar=True,
Length in samples of the sliding windows. Please note that the
w_epoch_t parameter would not be used if the sliding window approach
is used. If None, no sliding window would be applied.
baseline_mode : {'run', 'trial', None, 'sliding'}
baseline_mode : {'run', 'trial', 'sliding', None}
The baseline_mode has an additional feature when sliding window is
used:
- "run": common baseline extracted from the start of the run.
- "trial": common baseline for the trial, i.e., all the windows
that belongs to an onset; extracted relative to the onset.
- None: no baseline.
- "sliding": baseline is applied to each sliding window,
and it is relative to the start of each sliding window.
- None: no baseline.
Returns
-------
Expand Down Expand Up @@ -667,7 +667,7 @@ def transform_dataset(self, dataset, show_progress_bar=True,
desc='Extracting features')

# Compute features
features = None
features_list = list()
for run_counter, rec in enumerate(dataset.recordings):
# Extract recording experiment and biosignal
rec_exp = getattr(rec, dataset.experiment_att_key)
Expand Down Expand Up @@ -724,9 +724,9 @@ def transform_dataset(self, dataset, show_progress_bar=True,
target_fs=target_fs,
concatenate_channels=concatenate_channels
)
features = np.concatenate((features, rec_feat), axis=0) \
if features is not None else rec_feat

features_list.append(rec_feat)
features = np.concatenate((features_list), axis=0) if (
features_list) else None
# Track experiment info
for key, value in track_attributes.items():
if value['parent'] is None:
Expand Down Expand Up @@ -1138,12 +1138,10 @@ def fit_dataset(self, dataset, get_training_accuracy=True,
k_fold_iter = cu.k_fold_split(x, x_info['mi_labels'], k_fold)
k_fold_acc = 0
for iter in k_fold_iter:
self.get_inst('clf_method').fit(
iter["x_train"], iter["y_train"])
y_test_pred = self.get_inst('clf_method').predict(
iter["x_test"])
y_test_prob = self.get_inst('clf_method').predict_proba(
iter["x_test"])
clf = copy.deepcopy(self.get_inst('clf_method'))
clf.fit(iter["x_train"], iter["y_train"])
y_test_pred = clf.predict(iter["x_test"])
y_test_prob = clf.predict_proba(iter["x_test"])
fold_acc = np.sum(y_test_pred == iter["y_test"]) / \
len(iter["y_test"])
k_fold_acc += fold_acc
Expand Down Expand Up @@ -1325,74 +1323,106 @@ class MIModelEEGSym(MIModel):
def __init__(self):
super().__init__()

def configure(self, cnn_n_cha=8, ch_lateral=3, fine_tuning=False,
shuffle_before_fit=True, validation_split=0.4,
def configure(self, p_filt_cutoff=(0.1, 45), w_epoch_t=(0, 2000),
w_baseline_t=(0, 2000), target_fs=128, cnn_n_cha=8,
ch_lateral=3, fine_tuning=False, validation_split=0.4,
init_weights_path=None, gpu_acceleration=False,
augmentation=False):
augmentation=False, **kwargs):
self.settings = {
# StandardPreprocessing
'p_filt_cutoff': p_filt_cutoff,
# StandardFeatureExtraction
'w_epoch_t': w_epoch_t,
'baseline_mode': 'sliding',
'w_baseline_t': w_baseline_t,
'norm': 'z',
'target_fs': target_fs,
'concatenate_channels': False,
'sliding_w_lims_t': None,
'sliding_t_step': None,
'sliding_win_len': None,
# EEGSym features
'cnn_n_cha': cnn_n_cha,
'ch_lateral': ch_lateral,
'fine_tuning': fine_tuning,
'augmentation': augmentation,
'shuffle_before_fit': shuffle_before_fit,
'validation_split': validation_split,
'init_weights_path': init_weights_path,
'gpu_acceleration': gpu_acceleration
}
self.settings = dict(self.settings, **kwargs)
# Update state
self.is_configured = True
self.is_built = False
self.is_fit = False

def build(self):
""" Initializes the different methods that comprise the MIModelEEGSym
pipeline.
"""
# Check errors
if not self.is_configured:
raise ValueError('Function configure must be called first!')
# Only import deep learning models if necessary
from medusa.deep_learning_models import EEGSym
# Preprocessing (bandpass IIR filter [0.5, 45] Hz + CAR)
# Preprocessing (default: bandpass IIR filter [0.5, 45] Hz + CAR)
self.add_method('prep_method',
StandardPreprocessing(cutoff=49, btype='lowpass'))
StandardPreprocessing(cutoff=self.settings['p_filt_cutoff']))
# Feature extraction (epochs [0, 2000] ms + resampling to 128 Hz)
self.add_method('ext_method', StandardFeatureExtraction(safe_copy=True))
self.add_method('ext_method', StandardFeatureExtraction(**self.settings))
# Feature classification
clf = EEGSym(
input_time=2000,
fs=128,
input_time=int(self.settings['w_epoch_t'][1] -
self.settings['w_epoch_t'][0]),
fs=self.settings['target_fs'],
n_cha=self.settings['cnn_n_cha'],
ch_lateral=self.settings['ch_lateral'],
filters_per_branch=24,
scales_time=(125, 250, 500),
dropout_rate=0.4,
activation='elu', n_classes=2,
learning_rate=0.0001,
learning_rate=0.001,
gpu_acceleration=self.settings['gpu_acceleration'])
self.is_fit = False
if self.settings['init_weights_path'] is not None:
clf.load_weights(self.settings['init_weights_path'])
self.channel_set = meeg.EEGChannelSet()
standard_lcha = ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4']
standard_lcha = ['F7', 'C3', 'PO3', 'CZ', 'PZ', 'F8', 'C4', 'PO4']
self.channel_set.set_standard_montage(standard_lcha)
self.get_inst('prep_method').fit(fs=250, n_cha=8)
self.is_fit = True
else:
self.is_fit = False
self.add_method('clf_method', clf)
# Update state
self.is_built = True
# self.is_fit = False

def fit_dataset(self, dataset, continuous=False, **kwargs):
def fit_dataset(self, dataset, **kwargs):
""" Function to fit a dataset using MIModelCSP.
Parameters
-------------
dataset : MIDataset
MI dataset used for training.
**kwargs : dict
These parameters will be overwritten over self.settings.
Returns
-------------
assessment : dict
Dictionary containing the details of the training accuracy
estimation.
"""
# Check errors
if not self.is_built:
raise ValueError('Function build must be called first!')
# Merge settings
settings = dict(self.settings, **kwargs)
# Preprocessing
dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
# Extract features
# TODO: perform sliding window? if so, we need the lims
# TODO: hardcoded?
x, x_info = self.get_inst('ext_method').transform_dataset(
dataset, w_epoch_t=(0, 2000), baseline_mode="trial",
w_baseline_t=(0, 2000), norm="z", target_fs=128,
sliding_w_lims_t=None, sliding_t_step=None, sliding_win_len=None
)
x, x_info = self.get_inst('ext_method').transform_dataset(dataset,
**settings)
# Put channels in symmetric order
x, _ = self.get_inst('clf_method').symmetric_channels(
x, dataset.channel_set.l_cha)
Expand All @@ -1401,12 +1431,12 @@ def fit_dataset(self, dataset, continuous=False, **kwargs):
self.get_inst('clf_method').fit(
x, x_info['mi_labels'],
fine_tuning=self.settings['fine_tuning'],
shuffle_before_fit=self.settings['shuffle_before_fit'],
validation_split=self.settings['validation_split'],
augmentation=self.settings['augmentation'],
**kwargs)

y_prob = self.get_inst('clf_method').predict_proba(x)
y_pred = y_prob.argmax(axis=-1)
y_pred = self.get_inst('clf_method').predict(x)

# Accuracy
accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
Expand All @@ -1427,27 +1457,49 @@ def fit_dataset(self, dataset, continuous=False, **kwargs):
return assessment

def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
""" Function to predict an individual signal in MIModelCSP.
Parameters
--------------
times : ndarray (n_samples,)
Timestamp array
signal: ndarray (n_samples x n_channels)
Signal data.
fs : int
Sampling frequency.
channel_set : EEGChannelSet or similar
Channel montage.
x_info : dict
Dictionary containing the trial "onsets" and "mi_labels". If the
latter is not specified, accuracy is not calculated.
**kwargs : dict
These parameters will be overwritten over self.settings.
Returns
-------------
decoding : dict
Dictionary containing the decoding.
"""
# Check errors
if not self.is_fit:
raise ValueError('Function fit_dataset must be called first!')
# Merge settings
settings = dict(self.settings, **kwargs)
# Check channel set
if self.channel_set.channels != channel_set.channels:
if self.channel_set.l_cha != channel_set.l_cha:
warnings.warn('The channel set is not the same that was used to '
'fit the model. Be careful!')
# Preprocessing
signal = self.get_inst('prep_method').fit_transform_signal(signal, fs)
signal = self.get_inst('prep_method').transform_signal(signal)

# Extract features
# TODO: hardcoded?
x = self.get_inst('ext_method').transform_signal(
times=times, signal=signal, fs=fs, onsets=x_info['onsets'],
w_epoch_t=(0, 2000), baseline_mode="trial", w_baseline_t=(0, 2000),
norm="z", target_fs=128
)
**settings)

# Put channels in symmetric order
x, _ = self.get_inst('clf_method').symmetric_channels(
x, self.channel_set.l_cha)
x, channel_set.l_cha)

# Classification
y_prob = self.get_inst('clf_method').predict_proba(x)
Expand All @@ -1469,3 +1521,61 @@ def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
'report': clf_report
}
return decoding

def predict_dataset(self, dataset, **kwargs):
""" Function to predict a dataset using MIModelCSP.
Parameters
-------------
dataset : MIDataset
Test dataset.
**kwargs : dict
These parameters will be overwritten over self.settings.
Returns
-------------
decoding : dict
Dictionary containing the decoding.
"""
# Check errors
if not self.is_fit:
raise ValueError('Function fit_dataset must be called first!')
# Check channel set
if self.channel_set.l_cha != dataset.channel_set.l_cha:
warnings.warn('The channel set is not the same that was used to '
'fit the model. Be careful!')
# Merge settings
settings = dict(self.settings, **kwargs)

# Preprocessing
dataset = self.get_inst('prep_method').transform_dataset(dataset,
**settings)

# Extract features
x, x_info = self.get_inst('ext_method').transform_dataset(
dataset, **settings)

# Put channels in symmetric order
x, _ = self.get_inst('clf_method').symmetric_channels(x,
dataset.channel_set.l_cha)

# Classification
y_prob = self.get_inst('clf_method').predict_proba(x)
y_pred = y_prob.argmax(axis=-1)

# Decoding
accuracy = None
clf_report = None
if x_info['mi_labels'] is not None:
accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
clf_report = classification_report(x_info['mi_labels'], y_pred,
output_dict=True)
decoding = {
'x': x,
'x_info': x_info,
'y_prob': y_prob,
'y_pred': y_pred,
'accuracy': accuracy,
'report': clf_report
}
return decoding
Loading

0 comments on commit 00112a7

Please sign in to comment.