diff --git a/medusa/bci/cvep_spellers.py b/medusa/bci/cvep_spellers.py
index 74e4fdf..d7e4774 100644
--- a/medusa/bci/cvep_spellers.py
+++ b/medusa/bci/cvep_spellers.py
@@ -10,6 +10,7 @@
 from medusa import meeg
 from medusa import spatial_filtering as sf
 from medusa import epoching as ep
+from medusa import classification_utils as clf_utils
 
 import copy, warnings
 import itertools
@@ -17,6 +18,8 @@
 import numpy as np
 from tqdm import tqdm
 
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
 LFSR_PRIMITIVE_POLYNOMIALS = \
     {
         'base': {
@@ -112,7 +115,7 @@
     class for feature extraction and command decoding functions of the module
 
     def __init__(self, mode, paradigm_conf, commands_info, onsets, command_idx,
                  unit_idx, level_idx, matrix_idx, cycle_idx, trial_idx,
-                 cvep_model, spell_result, fps_resolution, spell_target=None,
+                 spell_result, fps_resolution, spell_target=None,
                  raster_events=None, **kwargs):
 
         # Check errors
@@ -130,7 +133,6 @@ def __init__(self, mode, paradigm_conf, commands_info, onsets, command_idx,
         self.matrix_idx = matrix_idx
         self.cycle_idx = cycle_idx
         self.trial_idx = trial_idx
-        self.cvep_model = cvep_model
         self.spell_result = spell_result
         self.fps_resolution = fps_resolution
         self.spell_target = spell_target
@@ -375,7 +377,559 @@ def custom_operations_on_recordings(self, recording):
         return recording
 
 
+def decode_commands_from_events(event_scores, commands_info, event_run_idx,
+                                event_trial_idx, event_cycle_idx):
+    """Command decoder for c-VEP-based spellers based on the bitwise
+    reconstruction paradigm (BWR), i.e., models that predict the command
+    sequence stimulus by stimulus.
+
+    ToDo: allow multi-matrix paradigms with different number of levels. See
+    module erp_based_spellers for reference.
+
+    Parameters
+    ----------
+    event_scores : list or np.ndarray
+        Array with the score for each stimulation
+    commands_info : list or np.ndarray
+        Array containing the unified speller matrix structure with shape
+        [n_runs x n_matrices x n_units x n_groups x n_batches x
+        n_commands/batch]. All c-VEP-based speller paradigms can be adapted
+        to this format and use this function for command decoding. See
+        CVEPSpellerData class for more info.
+    event_run_idx : list or numpy.ndarray [n_stim x 1]
+        Index of the run for each stimulation. This variable is
+        automatically retrieved by function
+        BWRFeatureExtraction.transform_dataset as part of the track info
+        dict. The run indexes must be related to paradigm_conf, keeping the
+        same order. Therefore, paradigm_conf[np.unique(run_idx)[0]] must
+        retrieve the paradigm configuration of run 0.
+    event_trial_idx : list or numpy.ndarray [n_stim x 1]
+        Index of the trial for each stimulation. A trial represents the
+        selection of a final command. Depending on the number of levels,
+        the final selection takes N intermediate selections.
+    event_cycle_idx : list or numpy.ndarray [n_stim x 1]
+        Index of the cycle for each stimulation. A cycle represents a
+        round of stimulation: all commands have been highlighted once.
+        This function supports dynamic stopping at different levels.
+
+    Returns
+    -------
+    selected_commands: list
+        Selected command for each trial considering all sequences of
+        stimulation. Each command is organized in an array [matrix_idx,
+        command_id]. Take into account that the command ids are unique for
+        each matrix, and therefore only the command of the last level should
+        be useful to take action.
+        Shape [n_runs x n_trials x n_levels x 2]
+    selected_commands_per_cycle: list
+        Selected command for each trial and sequence of stimulation. The
+        fourth dimension of the array contains [matrix_idx, command_id]. To
+        calculate the command for each sequence, it takes into account the
+        scores of all the previous sequences as well. Shape [n_runs x
+        n_trials x n_levels x n_cycles x 2]
+    cmd_scores_per_cycle: list
+        Scores for each command per cycle. Shape [n_runs x n_trials x
+        n_levels x n_cycles x n_commands x 1]. The score of each cycle
+        is calculated using all the previous cycles as well.
+    """
+    # Decode commands
+    selected_commands = list()
+    selected_commands_per_cycle = list()
+    scores = list()
+    for r, run in enumerate(np.unique(event_run_idx)):
+        # Get run data
+        run_event_scores = event_scores[event_run_idx == run]
+        run_event_cycle_idx = event_cycle_idx[event_run_idx == run]
+        run_event_trial_idx = event_trial_idx[event_run_idx == run]
+        # Initialize
+        run_selected_commands = list()
+        run_selected_commands_per_cycle = list()
+        run_cmd_scores = list()
+        # Iterate trials
+        for t, trial in enumerate(np.unique(run_event_trial_idx)):
+            # Get trial data
+            trial_event_scores = run_event_scores[
+                run_event_trial_idx == trial]
+            trial_event_cycle_idx = run_event_cycle_idx[
+                run_event_trial_idx == trial]
+            # Initialize
+            trial_cmd_scores_per_cycle = list()
+            trial_selected_commands_per_cycle = list()
+            # Iterate cycles
+            for c, cycle in enumerate(np.unique(trial_event_cycle_idx)):
+                cycle_event_scores = trial_event_scores[
+                    trial_event_cycle_idx <= cycle]
+                # Get target sequences
+                cmd_ids = list()
+                cmd_seqs = list()
+                for cmd_id, cmd_info in commands_info[r][0].items():
+                    cmd_ids.append(cmd_id)
+                    cmd_seqs.append(cmd_info['sequence'] * (c+1))
+                # Calculate correlations to all commands
+                corr_scores = np.abs(
+                    np.corrcoef(cycle_event_scores, cmd_seqs)[0, 1:])
+                # Save cycle data
+                cmd_id = cmd_ids[np.argmax(corr_scores)]
+                trial_cmd_scores_per_cycle.append(corr_scores)
+                trial_selected_commands_per_cycle.append([0, cmd_id])
+            # Save trial data
+            # ToDo: add another loop for levels
+            run_selected_commands.append(
+                [trial_selected_commands_per_cycle[-1]])
+            run_selected_commands_per_cycle.append(
+                [trial_selected_commands_per_cycle])
+            run_cmd_scores.append(
+                [trial_cmd_scores_per_cycle])
+        # Save run data
+        selected_commands.append(run_selected_commands)
+        selected_commands_per_cycle.append(run_selected_commands_per_cycle)
+        scores.append(run_cmd_scores)
+
+    return selected_commands, selected_commands_per_cycle, scores
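For orientation, a minimal sketch of how this decoder behaves on synthetic scores (hypothetical two-command matrix with 4-bit codes; the data is made up, only the function itself comes from this patch):

    import numpy as np
    from medusa.bci.cvep_spellers import decode_commands_from_events

    # One run, one trial, two cycles of a 4-bit code, two commands
    commands_info = [[{'A': {'sequence': [1, 0, 1, 0]},
                       'B': {'sequence': [1, 1, 0, 0]}}]]
    # Bit-wise scores from a BWR classifier: a noisy copy of A's code
    event_scores = np.array([.9, .1, .8, .2, .7, .3, .9, .4])
    sel, sel_per_cycle, scores = decode_commands_from_events(
        event_scores=event_scores,
        commands_info=commands_info,
        event_run_idx=np.zeros(8, dtype=int),
        event_trial_idx=np.zeros(8, dtype=int),
        event_cycle_idx=np.repeat([0, 1], 4))
    print(sel)  # [[[[0, 'A']]]] -> matrix 0, command 'A'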
+def command_decoding_accuracy_per_cycle(selected_commands_per_cycle,
+                                        target_commands):
+    """
+    Computes the accuracy of the selected sequence of targets given the
+    target commands
+
+    Parameters
+    ----------
+    selected_commands_per_cycle: list
+        List with the spell result per sequence as given by function
+        decode_commands. Shape [n_runs x n_trials x n_levels x n_cycles x 2]
+    target_commands: list
+        Target commands. Each position contains the matrix index and command
+        id per level that identifies the target command of the trial. Shape
+        [n_runs x n_trials x n_levels x 2]
+
+    Returns
+    -------
+    acc_per_cycle : numpy.ndarray
+        Accuracy of the command decoding stage for each number of cycles
+        considered in the analysis. Shape [n_sequences]
+    """
+    # Check errors
+    selected_commands_per_cycle = list(selected_commands_per_cycle)
+    target_commands = list(target_commands)
+    if len(selected_commands_per_cycle) != len(target_commands):
+        raise ValueError('Parameters selected_commands_per_cycle and '
+                         'target_commands must have the same length.')
+
+    # Compute accuracy per sequence
+    bool_result_per_seq = []
+    n_seqs = []
+    for r in range(len(selected_commands_per_cycle)):
+        r_sel_cmd_per_seq = selected_commands_per_cycle[r]
+        r_spell_target = target_commands[r]
+        for t in range(len(r_sel_cmd_per_seq)):
+            t_sel_cmd_per_seq = r_sel_cmd_per_seq[t]
+            t_spell_target = r_spell_target[t]
+            t_bool_result_per_seq = []
+            t_n_seqs = []
+            for l in range(len(t_sel_cmd_per_seq)):
+                l_sel_cmd_per_seq = t_sel_cmd_per_seq[l]
+                l_spell_target = t_spell_target[l]
+                t_bool_result_per_seq.append(list())
+                t_n_seqs.append(len(l_sel_cmd_per_seq))
+                for s in range(len(l_sel_cmd_per_seq)):
+                    s_sel_cmd_per_seq = l_sel_cmd_per_seq[s]
+                    t_bool_result_per_seq[l].append(l_spell_target ==
+                                                    s_sel_cmd_per_seq)
+
+            # Calculate the trial result per seq (all levels must be correct)
+            t_n_levels = len(t_sel_cmd_per_seq)
+            t_max_n_seqs = np.max(t_n_seqs)
+            t_acc_per_seq = np.empty((t_max_n_seqs, t_n_levels))
+            t_acc_per_seq[:] = np.nan
+            for l in range(t_n_levels):
+                t_acc_per_seq[:t_n_seqs[l], l] = t_bool_result_per_seq[l]
+            bool_result_per_seq.append(np.all(t_acc_per_seq, axis=1))
+            n_seqs.append(t_max_n_seqs)
+
+    # Calculate the accuracy per number of sequences considered in the analysis
+    max_n_seqs = np.max(n_seqs)
+    n_trials = len(bool_result_per_seq)
+    acc_per_seq = np.empty((max_n_seqs, n_trials))
+    acc_per_seq[:] = np.nan
+    for t in range(n_trials):
+        acc_per_seq[:n_seqs[t], t] = bool_result_per_seq[t]
+
+    return np.nanmean(acc_per_seq, axis=1)
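A quick sanity check of the accuracy computation with dummy nested lists (one run, one trial, one level, two cycles; wrong on the first cycle, right on the second; the data is invented for illustration):

    from medusa.bci.cvep_spellers import command_decoding_accuracy_per_cycle

    sel_per_cycle = [[[[[0, 'B'], [0, 'A']]]]]
    targets = [[[[0, 'A']]]]
    print(command_decoding_accuracy_per_cycle(sel_per_cycle, targets))
    # [0. 1.]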
+
+
+# ---------------------------------- MODELS ---------------------------------- #
+class CVEPSpellerModel(components.Algorithm):
+
+    def __init__(self):
+        """Class constructor
+        """
+        super().__init__(fit_dataset=['spell_target',
+                                      'spell_result_per_cycles',
+                                      'spell_acc_per_cycles'],
+                         predict=['spell_result',
+                                  'spell_result_per_cycles'])
+        # Settings
+        self.settings = None
+        self.channel_set = None
+        self.configure()
+        # Configuration
+        self.is_configured = False
+        self.is_built = False
+        self.is_fit = False
+
+    @abstractmethod
+    def configure(self, **kwargs):
+        """This function must be used to configure the model before calling
+        build method. Class attribute settings must be set with a dict
+        """
+        # Update state
+        self.is_configured = True
+        self.is_built = False
+        self.is_fit = False
+
+    @abstractmethod
+    def build(self, *args, **kwargs):
+        """This function builds the model, adding all the processing methods
+        to the pipeline. It must be called after configure.
+        """
+        # Check errors
+        if not self.is_configured:
+            raise ValueError('Function configure must be called first!')
+        # Update state
+        self.is_built = True
+        self.is_fit = False
+
+    @abstractmethod
+    def fit_dataset(self, dataset, **kwargs):
+        pass
+
+    @abstractmethod
+    def predict_dataset(self, dataset, **kwargs):
+        pass
+
+    @abstractmethod
+    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
+        pass
+
+
+class CMDModelBWRLDA(CVEPSpellerModel):
+    """Class that uses the bitwise reconstruction (BWR) paradigm with an LDA
+    classifier"""
+
+    def __init__(self):
+        super().__init__()
+
+    def configure(self, bpf=(7, (1.0, 60.0)), notch=(7, (49.0, 51.0)),
+                  w_epoch_t=(0, 500), target_fs=None):
+        self.settings = {
+            'bpf': bpf,
+            'notch': notch,
+            'w_epoch_t': w_epoch_t,
+            'target_fs': target_fs
+        }
+        # Update state
+        self.is_configured = True
+        self.is_built = False
+        self.is_fit = False
+
+    def build(self):
+        # Preprocessing
+        bpf = self.settings['bpf']
+        notch = self.settings['notch']
+        if notch is not None:
+            self.add_method('prep_method', StandardPreprocessing(
+                bpf_order=bpf[0], bpf_cutoff=bpf[1],
+                notch_order=notch[0], notch_cutoff=notch[1]))
+        else:
+            self.add_method('prep_method', StandardPreprocessing(
+                bpf_order=bpf[0], bpf_cutoff=bpf[1],
+                notch_order=None, notch_cutoff=None))
+
+        # Feature extraction
+        self.add_method('ext_method', BWRFeatureExtraction(
+            w_epoch_t=self.settings['w_epoch_t'],
+            target_fs=self.settings['target_fs'],
+            w_baseline_t=(-250, 0), norm='z',
+            concatenate_channels=True, safe_copy=True))
+
+        # Feature classification
+        clf = components.ProcessingClassWrapper(
+            LinearDiscriminantAnalysis(solver='eigen', shrinkage='auto'),
+            fit=[], predict_proba=['y_pred']
+        )
+        self.add_method('clf_method', clf)
+        # Update state
+        self.is_built = True
+        self.is_fit = False
+
+    def check_predict_feasibility_signal(self, times, cycle_onsets, fps,
+                                         code_len, fs):
+        return self.get_inst('ext_method').check_predict_feasibility_signal(
+            times, cycle_onsets, fps, code_len, fs)
+
+    def fit_dataset(self, dataset, show_progress_bar=False):
+        # Check errors
+        if not self.is_built:
+            raise ValueError('The model must be built first!')
+        # Preprocessing
+        dataset = self.get_inst('prep_method').fit_transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Extract features
+        x, x_info = self.get_inst('ext_method').transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Classification
+        self.get_inst('clf_method').fit(x, x_info['event_cvep_labels'])
+        # Save info
+        self.channel_set = dataset.channel_set
+        # Update state
+        self.is_fit = True
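Typical lifecycle of this model, as a hedged sketch (assumes a CVEPSpellerDataset instance named dataset recorded in train mode; everything else is defined by this patch):

    model = CMDModelBWRLDA()
    model.configure()   # or, e.g., configure(w_epoch_t=(0, 500), target_fs=250)
    model.build()
    model.fit_dataset(dataset, show_progress_bar=True)
    # Decode the commands of the same (or a held-out) dataset
    sel_cmd, cmd_assessment = model.predict_dataset(dataset)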
+    def predict_dataset(self, dataset, show_progress_bar=False):
+        # Check errors
+        if not self.is_fit:
+            raise ValueError('The model must be fitted first!')
+        # Preprocessing
+        dataset = self.get_inst('prep_method').fit_transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Extract features
+        x, x_info = self.get_inst('ext_method').transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Predict
+        y_pred = self.get_inst('clf_method').predict_proba(x)[:, 1]
+        # Command decoding
+        sel_cmd, sel_cmd_per_cycle, scores = decode_commands_from_events(
+            event_scores=y_pred,
+            commands_info=x_info['commands_info'],
+            event_run_idx=x_info['event_run_idx'],
+            event_trial_idx=x_info['event_trial_idx'],
+            event_cycle_idx=x_info['event_cycle_idx']
+        )
+        # Spell accuracy
+        cmd_assessment = None
+        if dataset.experiment_mode == 'train':
+            # Spell accuracy per cycle
+            spell_acc_per_cycle = command_decoding_accuracy_per_cycle(
+                sel_cmd_per_cycle,
+                x_info['spell_target']
+            )
+            cmd_assessment = {
+                'x': x,
+                'x_info': x_info,
+                'y_pred': y_pred,
+                'spell_result': sel_cmd,
+                'spell_result_per_cycle': sel_cmd_per_cycle,
+                'spell_acc_per_cycle': spell_acc_per_cycle
+            }
+        return sel_cmd, cmd_assessment
+
+    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
+        # Check errors
+        if not self.is_fit:
+            raise ValueError('The model must be fitted first!')
+        # Check channel set
+        if self.channel_set != channel_set:
+            warnings.warn(
+                'The channel set is not the same that was used to fit the '
+                'model. Be careful!')
+        # Pre-processing
+        signal = self.get_inst('prep_method').transform_signal(signal=signal)
+        # Extract features
+        x = self.get_inst('ext_method').transform_signal(
+            times, signal, fs, x_info['cycle_onsets'],
+            x_info['fps'], x_info['code_len'])
+        # Predict
+        y_pred = self.get_inst('clf_method').predict_proba(x)[:, 1]
+        # Get run_idx, trial_idx and cycle_idx per stimulation event
+        event_run_idx = np.repeat(x_info['run_idx'], x_info['code_len'])
+        event_trial_idx = np.repeat(x_info['trial_idx'], x_info['code_len'])
+        event_cycle_idx = np.repeat(x_info['cycle_idx'], x_info['code_len'])
+        # Command decoding
+        sel_cmd, sel_cmd_per_cycle, scores = decode_commands_from_events(
+            event_scores=y_pred,
+            commands_info=x_info['commands_info'],
+            event_run_idx=event_run_idx,
+            event_trial_idx=event_trial_idx,
+            event_cycle_idx=event_cycle_idx
+        )
+        return sel_cmd, sel_cmd_per_cycle, scores
+
+
+class CMDModelBWREEGInception(CVEPSpellerModel):
+    """Class that uses the bitwise reconstruction (BWR) paradigm with an
+    EEG-Inception model"""
+
+    def __init__(self):
+        super().__init__()
+
+    def configure(self, bpf=(7, (1.0, 60.0)), notch=(7, (49.0, 51.0)),
+                  w_epoch_t=(0, 500), target_fs=200, n_cha=16,
+                  filters_per_branch=12, scales_time=(250, 125, 62.5),
+                  dropout_rate=0.15, activation='elu', n_classes=2,
+                  learning_rate=0.001, batch_size=256,
+                  max_training_epochs=500, validation_split=0.1,
+                  shuffle_before_fit=True):
+        self.settings = {
+            'bpf': bpf,
+            'notch': notch,
+            'w_epoch_t': w_epoch_t,
+            'target_fs': target_fs,
+            'n_cha': n_cha,
+            'filters_per_branch': filters_per_branch,
+            'scales_time': scales_time,
+            'dropout_rate': dropout_rate,
+            'activation': activation,
+            'n_classes': n_classes,
+            'learning_rate': learning_rate,
+            'batch_size': batch_size,
+            'max_training_epochs': max_training_epochs,
+            'validation_split': validation_split,
+            'shuffle_before_fit': shuffle_before_fit
+        }
+        # Update state
+        self.is_configured = True
+        self.is_built = False
+        self.is_fit = False
+
+    def build(self):
+        # Preprocessing
+        bpf = self.settings['bpf']
+        notch = self.settings['notch']
+        if notch is not None:
+            self.add_method('prep_method', StandardPreprocessing(
+                bpf_order=bpf[0], bpf_cutoff=bpf[1],
+                notch_order=notch[0], notch_cutoff=notch[1]))
+        else:
+            self.add_method('prep_method', StandardPreprocessing(
+                bpf_order=bpf[0], bpf_cutoff=bpf[1],
+                notch_order=None, notch_cutoff=None))
+
+        # Feature extraction
+        self.add_method('ext_method', BWRFeatureExtraction(
+            w_epoch_t=self.settings['w_epoch_t'],
+            target_fs=self.settings['target_fs'],
+            w_baseline_t=(-250, 0), norm='z',
+            concatenate_channels=False, safe_copy=True))
+
+        # Feature classification
+        from medusa.deep_learning_models import EEGInceptionv1
+        input_time = \
+            self.settings['w_epoch_t'][1] - self.settings['w_epoch_t'][0]
+        clf = EEGInceptionv1(
+            input_time=input_time,
+            fs=self.settings['target_fs'],
+            n_cha=self.settings['n_cha'],
+            filters_per_branch=self.settings['filters_per_branch'],
+            scales_time=self.settings['scales_time'],
+            dropout_rate=self.settings['dropout_rate'],
+            activation=self.settings['activation'],
+            n_classes=self.settings['n_classes'],
+            learning_rate=self.settings['learning_rate'])
+        self.add_method('clf_method', clf)
+
+        # Update state
+        self.is_built = True
+        self.is_fit = False
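For orientation, the input tensor EEG-Inception sees here follows from w_epoch_t and target_fs; a back-of-the-envelope check with the defaults of configure above (plain Python, illustrative only):

    w_epoch_t = (0, 500)   # ms
    target_fs = 200        # Hz
    n_cha = 16
    input_time = w_epoch_t[1] - w_epoch_t[0]        # 500 ms
    n_samples = int(input_time / 1000 * target_fs)  # 100 samples
    # each event epoch enters the network as a 100 x 16 array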
+    def check_predict_feasibility_signal(self, times, cycle_onsets, fps,
+                                         code_len, fs):
+        return self.get_inst('ext_method').check_predict_feasibility_signal(
+            times, cycle_onsets, fps, code_len, fs)
+
+    def fit_dataset(self, dataset, show_progress_bar=False):
+        # Check errors
+        if not self.is_built:
+            raise ValueError('The model must be built first!')
+        # Preprocessing
+        dataset = self.get_inst('prep_method').fit_transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Extract features
+        x, x_info = self.get_inst('ext_method').transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Classification
+        self.get_inst('clf_method').fit(
+            x, x_info['event_cvep_labels'],
+            shuffle_before_fit=self.settings['shuffle_before_fit'],
+            epochs=self.settings['max_training_epochs'],
+            validation_split=self.settings['validation_split'],
+            batch_size=self.settings['batch_size'])
+        # Save info
+        self.channel_set = dataset.channel_set
+        # Update state
+        self.is_fit = True
+
+    def predict_dataset(self, dataset, show_progress_bar=False):
+        # Check errors
+        if not self.is_fit:
+            raise ValueError('The model must be fitted first!')
+        # Preprocessing
+        dataset = self.get_inst('prep_method').fit_transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Extract features
+        x, x_info = self.get_inst('ext_method').transform_dataset(
+            dataset, show_progress_bar=show_progress_bar)
+        # Predict
+        y_pred = self.get_inst('clf_method').predict_proba(x)
+        y_pred = clf_utils.categorical_labels(y_pred)
+        # Command decoding
+        sel_cmd, sel_cmd_per_cycle, scores = decode_commands_from_events(
+            event_scores=y_pred,
+            commands_info=x_info['commands_info'],
+            event_run_idx=x_info['event_run_idx'],
+            event_trial_idx=x_info['event_trial_idx'],
+            event_cycle_idx=x_info['event_cycle_idx']
+        )
+        # Spell accuracy
+        cmd_assessment = None
+        if dataset.experiment_mode == 'train':
+            # Spell accuracy per cycle
+            spell_acc_per_cycle = command_decoding_accuracy_per_cycle(
+                sel_cmd_per_cycle,
+                x_info['spell_target']
+            )
+            cmd_assessment = {
+                'x': x,
+                'x_info': x_info,
+                'y_pred': y_pred,
+                'spell_result': sel_cmd,
+                'spell_result_per_cycle': sel_cmd_per_cycle,
+                'spell_acc_per_cycle': spell_acc_per_cycle
+            }
+        return sel_cmd, cmd_assessment
+
+    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
+        # Check errors
+        if not self.is_fit:
+            raise ValueError('The model must be fitted first!')
+        # Check channel set
+        if self.channel_set != channel_set:
+            warnings.warn(
+                'The channel set is not the same that was used to fit the '
+                'model. Be careful!')
+        # Pre-processing
+        signal = self.get_inst('prep_method').transform_signal(signal=signal)
+        # Extract features
+        x = self.get_inst('ext_method').transform_signal(
+            times, signal, fs, x_info['cycle_onsets'],
+            x_info['fps'], x_info['code_len'])
+        # Predict
+        y_pred = self.get_inst('clf_method').predict_proba(x)
+        y_pred = clf_utils.categorical_labels(y_pred)
+        # Get run_idx, trial_idx and cycle_idx per stimulation event
+        event_run_idx = np.repeat(x_info['run_idx'], x_info['code_len'])
+        event_trial_idx = np.repeat(x_info['trial_idx'], x_info['code_len'])
+        event_cycle_idx = np.repeat(x_info['cycle_idx'], x_info['code_len'])
+        # Command decoding
+        sel_cmd, sel_cmd_per_cycle, scores = decode_commands_from_events(
+            event_scores=y_pred,
+            commands_info=x_info['commands_info'],
+            event_run_idx=event_run_idx,
+            event_trial_idx=event_trial_idx,
+            event_cycle_idx=event_cycle_idx
+        )
+        return sel_cmd, sel_cmd_per_cycle, scores
+
+
 class CVEPModelCircularShifting(components.Algorithm):
 
     def __init__(self, bpf=[[7, (1.0, 30.0)]], notch=[7, (49.0, 51.0)],
@@ -388,22 +942,27 @@ def __init__(self, bpf=[[7, (1.0, 30.0)]], notch=[7, (49.0, 51.0)],
             self.add_method('prep_method', StandardPreprocessing(
                 bpf_order=bpf[0][0], bpf_cutoff=bpf[0][1],
                 notch_order=notch[0], notch_cutoff=notch[1]))
+            max_order = max(bpf[0][0], notch[0])
         else:
             self.add_method('prep_method', StandardPreprocessing(
                 bpf_order=bpf[0][0], bpf_cutoff=bpf[0][1],
                 notch_order=None, notch_cutoff=None))
+            max_order = bpf[0][0]
         else:
             filter_bank = []
+            max_order = 0
             for i in range(len(bpf)):
                 filter_bank.append({
                     'order': bpf[i][0],
                     'cutoff': bpf[i][1],
                     'btype': 'bandpass'
                 })
+                max_order = bpf[i][0] if bpf[i][0] > max_order else max_order
             if notch is not None:
                 self.add_method('prep_method', FilterBankPreprocessing(
                     filter_bank=filter_bank, notch_order=notch[0],
                     notch_cutoff=notch[1]))
+                max_order = max(max_order, notch[0])
             else:
                 self.add_method('prep_method', FilterBankPreprocessing(
                     filter_bank=filter_bank, notch_order=None,
@@ -412,7 +971,8 @@ def __init__(self, bpf=[[7, (1.0, 30.0)]], notch=[7, (49.0, 51.0)],
         # Feature extraction and classification (circular shifting)
         self.add_method('clf_method', CircularShiftingClassifier(
             art_rej=art_rej,
-            correct_raster_latencies=correct_raster_latencies
+            correct_raster_latencies=correct_raster_latencies,
+            extra_epoch_samples=3*max_order
         ))
 
         # Early stopping
@@ -425,7 +985,7 @@ def check_predict_feasibility_signal(self, times, onsets, fs):
         return self.get_inst('clf_method')._is_predict_feasible_signal(
             times, onsets, fs)
 
-    def fit_dataset(self, dataset, **kwargs):
+    def fit_dataset(self, dataset, roll_targets=False, **kwargs):
         # Safe copy
         data = copy.deepcopy(dataset)
@@ -438,8 +998,8 @@ def fit_dataset(self, dataset, roll_targets=False, **kwargs):
         # Feature extraction and classification
         fitted_info = self.get_inst('clf_method').fit_dataset(
             dataset=data,
-            std_epoch_rejection=None,
-            show_progress_bar=True
+            show_progress_bar=True,
+            roll_targets=roll_targets
         )
 
         return fitted_info
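The 3*max_order margin passed as extra_epoch_samples appears to budget additional samples for IIR filter edge effects when checking online feasibility; the arithmetic, with hypothetical values:

    # Filter bank with two order-7 band-pass filters plus an order-7 notch
    bpf = [[7, (1.0, 12.0)], [7, (12.0, 30.0)]]
    notch = [7, (49.0, 51.0)]
    max_order = max(max(f[0] for f in bpf), notch[0])  # 7
    extra_epoch_samples = 3 * max_order                # 21 extra samples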
@@ -646,6 +1206,276 @@ def transform_dataset(self, dataset: CVEPSpellerDataset,
         return dataset
 
 
+class BWRFeatureExtraction(components.ProcessingMethod):
+    """Feature extraction method designed to extract event-wise epochs from
+    c-VEP stimulation paradigms to perform bitwise reconstruction (BWR)
+    """
+
+    def __init__(self, w_epoch_t=(0, 500), target_fs=20,
+                 w_baseline_t=(-250, 0), norm='z',
+                 concatenate_channels=True, safe_copy=True):
+        """Class constructor
+
+        w_epoch_t : list
+            Temporal window in ms for each epoch relative to the event onset
+            (e.g., [0, 1000])
+        target_fs : float or None
+            Target sample rate of each epoch. If None, all the recordings
+            must have the same sample rate, so it is strongly recommended to
+            set this parameter to a suitable value to avoid problems and
+            save time
+        w_baseline_t : list
+            Temporal window in ms to be used for baseline normalization for
+            each epoch relative to the event onset (e.g., [-250, 0])
+        norm : str {'z'|'dc'}
+            Type of baseline normalization. Set to 'z' for Z-score
+            normalization or 'dc' for DC normalization
+        concatenate_channels : bool
+            This parameter controls the shape of the feature array. If True,
+            all channels will be concatenated, returning an array of shape
+            [n_events x (samples x channels)]. If false, the array will have
+            shape [n_events x samples x channels]
+        safe_copy : bool
+            Makes a safe copy of the signal to avoid changing the original
+            samples due to references
+        """
+        super().__init__(transform_signal=['x'],
+                         transform_dataset=['x', 'x_info'])
+        self.w_epoch_t = w_epoch_t
+        self.target_fs = target_fs
+        self.w_baseline_t = w_baseline_t
+        self.norm = norm
+        self.concatenate_channels = concatenate_channels
+        self.safe_copy = safe_copy
+
+    @staticmethod
+    def generate_bit_wise_onsets(cycle_onsets, frames_per_second, code_len):
+        # Generate bit-wise onsets
+        onsets = []
+        for o in cycle_onsets:
+            onsets += np.linspace(
+                o, o + (code_len - 1) / frames_per_second, code_len).astype(
+                float).tolist()
+        return onsets
+
+    def check_predict_feasibility_signal(self, times, cycle_onsets, fps,
+                                         code_len, fs):
+        # Generate bit-wise onsets, because, for BWR methods, we need
+        # w_epoch_t ms after the last stimulus of the sequence
+        bit_wise_onsets = self.generate_bit_wise_onsets(
+            cycle_onsets, fps, code_len)
+        check = ep.check_epochs_feasibility(
+            times, bit_wise_onsets, fs, self.w_epoch_t)
+        return True if check == 'ok' else False
+
+    def transform_signal(self, times, signal, fs, cycle_onsets, fps,
+                         code_len):
+        """Function to extract VEP features from raw signal. It returns a 3D
+        feature array with shape [n_events x n_samples x n_channels]. This
+        function does not track any other attributes. Use for online
+        processing and custom higher level functions.
+
+        Parameters
+        ----------
+        times : list or numpy.ndarray
+            1D numpy array [n_samples]. Timestamps of each sample. If they
+            are not available, generate them artificially. Nevertheless,
+            all signals and events must have the same temporal origin
+        signal : list or numpy.ndarray
+            2D numpy array [n_samples x n_channels]. EEG samples (the units
+            should be defined using kwargs)
+        fs : int or float
+            Sample rate of the recording.
+        cycle_onsets : list or numpy.ndarray [n_cycles x 1]
+            Timestamps indicating the start of each stimulation cycle
+        fps : int
+            Frames per second of the screen that presents the stimulation
+        code_len : int
+            Length of the c-VEP codes
+
+        Returns
+        -------
+        features : np.ndarray [n_events x n_samples x n_channels]
+            Feature array with the epochs of signal
+        """
+        # Avoid changes in the original signal (this may not be necessary)
+        if self.safe_copy:
+            signal = signal.copy()
+        # Get event-wise onsets
+        onsets = self.generate_bit_wise_onsets(cycle_onsets, fps, code_len)
+        # Extract features
+        features = mds.get_epochs_of_events(timestamps=times, signal=signal,
+                                            onsets=onsets, fs=fs,
+                                            w_epoch_t=self.w_epoch_t,
+                                            w_baseline_t=self.w_baseline_t,
+                                            norm=self.norm)
+        # Resample each epoch to the target frequency
+        if self.target_fs is not None:
+            if self.target_fs > fs:
+                raise ValueError('Target fs is greater than data fs')
+            features = mds.resample_epochs(features,
+                                           self.w_epoch_t,
+                                           self.target_fs)
+        # Reshape epochs and concatenate the channels
+        if self.concatenate_channels:
+            features = np.squeeze(features.reshape((features.shape[0],
+                                                    features.shape[1] *
+                                                    features.shape[2], 1)))
+        return features
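The bit-wise epoching above hinges on generate_bit_wise_onsets placing one onset per screen frame; a numpy-only sketch of that expansion (hypothetical numbers):

    import numpy as np
    fps, code_len = 60, 63        # 63-bit m-sequence presented at 60 Hz
    cycle_onsets = [10.0, 11.05]  # seconds; one onset per stimulation cycle
    onsets = []
    for o in cycle_onsets:
        onsets += np.linspace(o, o + (code_len - 1) / fps, code_len).tolist()
    print(len(onsets))  # 126 event onsets -> one epoch per code bit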
+    def transform_dataset(self, dataset, show_progress_bar=True):
+        """High level function to easily extract features from EEG recordings
+        and save useful info for later processing. Nevertheless, the provided
+        functionality has several limitations, and it will not be suitable
+        for all cases and processing pipelines. If it does not fit your
+        needs, create a custom function iterating the recordings and using
+        transform_signal, a much more low-level and general function. This
+        function does not apply any preprocessing to the signals, this must
+        be done before.
+
+        Parameters
+        ----------
+        dataset: CVEPSpellerDataset
+            List of data_structures.Recordings or data_structures.Dataset.
+            If this parameter is a list of recordings, the consistency of
+            the dataset will be checked. Otherwise, if the parameter is a
+            dataset, this function assumes that the consistency is already
+            checked
+        show_progress_bar: bool
+            Show progress bar
+
+        Returns
+        -------
+        features : numpy.ndarray
+            Array with the biosignal samples arranged in epochs
+        track_info : dict
+            Dictionary with tracked information across all recordings
+        """
+        # Avoid changes in the original recordings (this may not be
+        # necessary)
+        if self.safe_copy:
+            dataset = copy.deepcopy(dataset)
+        # Avoid consistency problems
+        if dataset.fs is None and self.target_fs is None:
+            raise ValueError('The consistency of the features is not assured '
+                             'since dataset.fs and target_fs are both None. '
+                             'Specify one of these parameters')
+
+        # Additional track attributes
+        track_attributes = dataset.track_attributes
+        track_attributes['run_idx'] = {
+            'track_mode': 'concatenate',
+            'parent': dataset.experiment_att_key
+        }
+        track_attributes['event_run_idx'] = {
+            'track_mode': 'concatenate',
+            'parent': dataset.experiment_att_key
+        }
+        track_attributes['event_trial_idx'] = {
+            'track_mode': 'concatenate',
+            'parent': dataset.experiment_att_key
+        }
+        track_attributes['event_cycle_idx'] = {
+            'track_mode': 'concatenate',
+            'parent': dataset.experiment_att_key
+        }
+        if dataset.experiment_mode == 'train':
+            track_attributes['event_cvep_labels'] = {
+                'track_mode': 'concatenate',
+                'parent': dataset.experiment_att_key
+            }
+
+        # Initialization
+        features = None
+        track_info = dict()
+        for key, value in track_attributes.items():
+            if value['track_mode'] == 'append':
+                track_info[key] = list()
+            elif value['track_mode'] == 'concatenate':
+                track_info[key] = None
+            else:
+                raise ValueError('Unknown track mode')
+
+        # Init progress bar
+        pbar = None
+        if show_progress_bar:
+            pbar = tqdm(total=len(dataset.recordings),
+                        desc='Extracting features')
+
+        # Compute features
+        run_counter = 0
+        trial_counter = 0
+        for rec in dataset.recordings:
+            # Extract recording experiment and biosignal
+            rec_exp = getattr(rec, dataset.experiment_att_key)
+            rec_sig = getattr(rec, dataset.biosignal_att_key)
+
+            # Get features
+            rec_feat = self.transform_signal(
+                times=rec_sig.times,
+                signal=rec_sig.signal,
+                fs=rec_sig.fs,
+                cycle_onsets=rec_exp.onsets,
+                fps=rec_exp.fps_resolution,
+                code_len=len(rec_exp.commands_info[0]['0']['sequence'])
+            )
+            features = np.concatenate((features, rec_feat), axis=0) \
+                if features is not None else rec_feat
+
+            # Special attributes that need tracking across runs to assure
+            # the consistency of the dataset
+            rec_exp.run_idx = run_counter * np.ones_like(rec_exp.trial_idx)
+            rec_exp.trial_idx = trial_counter + np.array(rec_exp.trial_idx)
+
+            # Event tracking attributes
+            seq_len = len(list(rec_exp.commands_info[0].values())[0][
+                              'sequence'])
+            rec_exp.event_run_idx = np.repeat(rec_exp.run_idx, seq_len)
+            rec_exp.event_trial_idx = np.repeat(rec_exp.trial_idx, seq_len)
+            rec_exp.event_cycle_idx = np.repeat(rec_exp.cycle_idx, seq_len)
+
+            # Get labels of the individual events as required in BWR method
+            if dataset.experiment_mode == 'train':
+                rec_exp.event_cvep_labels = np.array([])
+                for i, t in enumerate(np.unique(rec_exp.trial_idx)):
+                    # ToDo: add another loop for levels
+                    target = rec_exp.spell_target[i][0]
+                    cmd_mtx = target[0]
+                    cmd_id = target[1]
+                    cmd_seq = rec_exp.commands_info[cmd_mtx][cmd_id][
+                        'sequence']
+                    n_cycles = np.array(rec_exp.cycle_idx)[
+                        rec_exp.trial_idx == t][-1] + 1
+                    rec_exp.event_cvep_labels = np.concatenate(
+                        (rec_exp.event_cvep_labels, cmd_seq * int(n_cycles)),
+                        axis=0)
+
+            # Update counters of special attributes
+            run_counter += 1
+            trial_counter += np.unique(rec_exp.trial_idx).shape[0]
+
+            # Track experiment info
+            for key, value in track_attributes.items():
+                if value['parent'] is None:
+                    parent = rec
+                else:
+                    parent = rec
+                    for p in value['parent'].split('.'):
+                        parent = getattr(parent, p)
+                att = getattr(parent, key)
+                if value['track_mode'] == 'append':
+                    track_info[key].append(att)
+                elif value['track_mode'] == 'concatenate':
+                    track_info[key] = np.concatenate(
+                        (track_info[key], att), axis=0
+                    ) if track_info[key] is not None else att
+                else:
+                    raise ValueError('Unknown track mode')
+
+            if show_progress_bar:
+                pbar.update(1)
+        if show_progress_bar:
+            pbar.close()
+
+        return features, track_info
+
+
 class FilterBankPreprocessing(components.ProcessingMethod):
     """Just the common preprocessing applied in c-VEP-based spellers. Simple,
     quick and effective: frequency IIR band-pass and notch filters
@@ -799,7 +1629,8 @@ class CircularShiftingClassifier(components.ProcessingMethod):
     Basically, it computes a template for each sequence.
     """
 
-    def __init__(self, correct_raster_latencies=False, art_rej=None, **kwargs):
+    def __init__(self, correct_raster_latencies=False, art_rej=None,
+                 extra_epoch_samples=21, **kwargs):
         """ Class constructor """
         super().__init__(fit_dataset=['templates', 'cca_by_seq'])
@@ -807,6 +1638,7 @@ def __init__(self, correct_raster_latencies=False, art_rej=None, **kwargs):
 
         self.art_rej = art_rej
         self.correct_raster_latencies = correct_raster_latencies
+        self.extra_epoch_samples = extra_epoch_samples
 
     def _assert_consistency(self, dataset: CVEPSpellerDataset):
         # TODO: this function is not necessary. Use CVEPSpellerDataset
@@ -876,8 +1708,8 @@ def _assert_consistency(self, dataset: CVEPSpellerDataset):
 
         return fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank
 
-    def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
-                    show_progress_bar=True):
+    def fit_dataset(self, dataset: CVEPSpellerDataset,
+                    roll_targets=False, show_progress_bar=True):
 
         # Error checking
         fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank = \
@@ -905,6 +1737,15 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
 
             # Get unique sequences for this run
             unique_seqs = unique_seqs_by_run[rec_idx]
+            if roll_targets:
+                new_unique_seqs = dict()
+                for key, value in unique_seqs.items():
+                    c_ = rec_exp.command_idx[value[0]]
+                    c_lag_ = rec_exp.commands_info[0][str(int(c_))]['lag']
+                    c_seq_ = rec_exp.commands_info[0][str(int(c_))]['sequence']
+                    new_key = np.roll(c_seq_, c_lag_)
+                    new_unique_seqs[tuple(new_key)] = value
+                unique_seqs = new_unique_seqs
 
             # For each filter bank
             for filter_idx, signal in enumerate(rec_sig.signal):
@@ -917,6 +1758,16 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
                     w_epoch_t=[0, len_epoch_ms],
                     w_baseline_t=None,
                     norm=None)
+
+                # Roll targets if training was not made with the 0 lag command
+                if roll_targets:
+                    for idx_, c_ in enumerate(rec_exp.command_idx):
+                        # TODO: nested matrices
+                        c_lag_ = rec_exp.commands_info[0][str(int(c_))]['lag']
+                        lag_samples = int(np.round(c_lag_ / fps_resolution * fs))
+                        # Revert the lag in the epoch
+                        epochs[idx_, :, :] = np.roll(
+                            epochs[idx_, :, :], lag_samples, axis=0)
 
                 # Organize epochs by sequence
                 for seq_, ep_idxs_ in unique_seqs.items():
@@ -939,9 +1790,9 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
             if show_progress_bar:
                 pbar.update(1)
 
+        # Precompute nearest channels for online artifact rejection
         sorted_dist_ch = None
-        if std_epoch_rejection is not None:
-            # Precompute nearest channels for online artifact rejection
+        if self.art_rej is not None:
            sorted_dist_ch = rec_sig.channel_set.sort_nearest_channels()
 
         # New bar
@@ -960,7 +1811,7 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
                 for filter_idx in range(len(epochs_by_seq[seq_])):
 
                     # Offline artifact rejection
-                    if std_epoch_rejection is not None:
+                    if self.art_rej is not None:
                         epochs_std = np.std(epochs_by_seq[seq_][filter_idx],
                                             axis=1)  # STD per samples
                         ch_std = np.std(epochs_std, axis=0)  # Variation of epochs
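The rejection rule that follows keeps an epoch only if its per-channel standard deviation stays within art_rej (formerly std_epoch_rejection, default 3.0) channel STDs of the median; a numpy-only sketch of the criterion on synthetic data (illustrative, not the module's code):

    import numpy as np
    rng = np.random.default_rng(0)
    epochs_std = np.abs(rng.normal(1.0, 0.1, size=(50, 8)))  # [epochs x channels]
    art_rej = 3.0
    ch_std = np.std(epochs_std, axis=0)
    med = np.median(epochs_std, axis=0)
    keep = np.all((epochs_std < med + art_rej * ch_std) &
                  (epochs_std > med - art_rej * ch_std), axis=1)
    print(int(keep.sum()), 'of 50 epochs kept')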
@@ -970,11 +1821,11 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
                             epoch_to_keep[:, i] = (
                                     (epochs_std[:, i] < (
                                             np.median(epochs_std[:, i]) +
-                                            std_epoch_rejection * ch_std[
+                                            self.art_rej * ch_std[
                                                 i])) &
                                     (epochs_std[:, i] > (
                                             np.median(epochs_std[:, i]) -
-                                            std_epoch_rejection * ch_std[
+                                            self.art_rej * ch_std[
                                                 i]))
                             )
                         # Keep only epochs that are suitable for all channels
@@ -1033,7 +1884,7 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
                 'fps_resolution': fps_resolution,
                 'len_epoch_ms': len_epoch_ms,
                 'len_epoch_sam': len_epoch_sam,
-                'std_epoch_rejection': std_epoch_rejection,
+                'std_epoch_rejection': self.art_rej,
                 'no_discarded_epochs': discarded_epochs,
                 'no_total_epochs': total_epochs,
                 'sorted_dist_ch': sorted_dist_ch
@@ -1044,7 +1895,8 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
         return self.fitted
 
     def _is_predict_feasible(self, dataset):
-        l_ms = self.fitted['len_epoch_ms']
+        l_ms = self.fitted['len_epoch_ms'] + \
+               np.ceil(1000 * self.extra_epoch_samples / dataset.fs)
         for rec in dataset.recordings:
             rec_sig = getattr(rec, dataset.biosignal_att_key)
             rec_exp = getattr(rec, dataset.experiment_att_key)
@@ -1057,7 +1909,8 @@ def _is_predict_feasible(self, dataset):
         return True
 
     def _is_predict_feasible_signal(self, times, onsets, fs):
-        l_ms = self.fitted['len_epoch_ms']
+        l_ms = self.fitted['len_epoch_ms'] + \
+               np.ceil(1000 * self.extra_epoch_samples / fs)
         feasible = ep.check_epochs_feasibility(timestamps=times,
                                                onsets=onsets,
                                                fs=fs,
@@ -1118,7 +1971,9 @@ def predict(self, times, signal, trial_idx, exp_data, sig_data):
 
         # For each number of cycles
         pred_item_by_no_cycles = []
-        no_cycles = np.max(exp_data.cycle_idx).astype(int) + 1
+        _exp_cycle_idx = np.array(exp_data.cycle_idx)
+        no_cycles = np.max(_exp_cycle_idx[np.array(exp_data.trial_idx) ==
+                                          trial_idx]).astype(int) + 1
         for nc in range(no_cycles):
             # Identify what are the epochs that must be processed
             idx = (np.array(exp_data.trial_idx) == trial_idx) & \
@@ -1336,12 +2191,19 @@ def is_shifted_version(stored_seqs, seq_to_check):
             # interprets all dictionary keys as strings.
 
             # Add the command index to its associated sequence
-            if tuple(curr_seq_) not in sequences:
+            if len(sequences) == 0:
                 sequences[tuple(curr_seq_)] = [idx]
-            else:
+            elif tuple(curr_seq_) in sequences:
+                # Already there, add the cycle idx
                 sequences[tuple(curr_seq_)].append(idx)
-
-            # todo: check that sequences are not shifted versions of themselves??
+            else:
+                # If not there, first check that it is not a shifted version
+                # of an already stored sequence
+                orig_seq = is_shifted_version(list(sequences.keys()), curr_seq_)
+                if orig_seq is not None:
+                    sequences[tuple(orig_seq)].append(idx)
+                else:
+                    sequences[tuple(curr_seq_)] = [idx]
         except Exception as e:
             print(e)
     return sequences
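is_shifted_version, already defined in this module, resolves whether a code is a circular shift of a stored one; the core check can be sketched with numpy alone (illustrative helper, not the module's implementation):

    import numpy as np

    def find_shift_origin(stored_seqs, seq):
        # Return the stored sequence that seq is a circular shift of, if any
        for s in stored_seqs:
            if any(np.array_equal(np.roll(s, k), seq) for k in range(len(s))):
                return s
        return None

    print(find_shift_origin([(1, 0, 0, 1)], (0, 0, 1, 1)))  # (1, 0, 0, 1)
    print(find_shift_origin([(1, 0, 0, 1)], (1, 1, 1, 0)))  # None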
diff --git a/medusa/plots/timeplot.py b/medusa/plots/timeplot.py
index aa44076..fd5a533 100644
--- a/medusa/plots/timeplot.py
+++ b/medusa/plots/timeplot.py
@@ -76,7 +76,7 @@ def __plot_events_lines(ax, events_dict, min_val, max_val, display_times):
 
     # Create legend above the plot
     if previous_conditions is not None:
-        previous_handles = ax.legend_.legendHandles
+        previous_handles = ax.legend_.legend_handles
         for legend_line in list(legend_lines.values()):
             previous_handles.append(legend_line)
             previous_conditions.append(legend_line._label)
@@ -200,7 +200,7 @@ def __reshape_signal(epochs):
     return epoch_c
 
 
-def time_plot(signal, fs=1.0, ch_labels=None, time_to_show=None,
+def time_plot(signal, times=None, fs=1.0, ch_labels=None, time_to_show=None,
               ch_to_show=None, ch_offset=None, color='k',
               conditions_dict=None, events_dict=None, show_epoch_lines=True,
               fig=None, axes=None):
@@ -210,6 +210,8 @@ def time_plot(signal, times=None, fs=1.0, ch_labels=None, time_to_show=None,
     signal: numpy ndarray
         Signal with shape of [n_epochs, n_samples, n_channels] or
         [n_samples, n_channels]
+    times: numpy ndarray
+        Timestamps of each sample of the signal with shape [n_samples]
     fs: float
         Sampling rate. Value 1 as default
     ch_labels: list of strings or None
@@ -317,8 +319,11 @@ def time_plot(signal, times=None, fs=1.0, ch_labels=None, time_to_show=None,
     max_val, min_val = epoch_c.max(), epoch_c.min()
 
     # Define times vector
-    display_times = np.linspace(0, int(epoch_c.shape[0] / fs),
-                                epoch_c.shape[0])
+    if times is None:
+        display_times = np.linspace(0, (epoch_c.shape[0] - 1) / fs,
+                                    epoch_c.shape[0])
+    else:
+        display_times = times
 
     # Initialize plot
     if fig is None:
@@ -507,4 +512,4 @@ def on_key(event):
     # Initialize TimePlot instance
     time_plot(signal=signal,fs=fs,ch_labels=l_cha,time_to_show=None,
               ch_to_show=None,ch_offset=None,conditions_dict=c_dict,
-              events_dict=e_dict,show_epoch_lines=True,show=True)
+              events_dict=e_dict,show_epoch_lines=True)
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 394b7d4..9bc4781 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 setup(
     name='medusa-kernel',
     packages=find_packages(),
-    version='1.2.5',
+    version='1.3.0',
     keywords=['Signal', 'Biosignal', 'EEG', 'BCI'],
     url='https://medusabci.com/',
     author='Eduardo Santamaría-Vázquez, '