Commit 9543c12

Merge remote-tracking branch 'origin/developers' into developers
esantamariavazquez committed Nov 8, 2023
2 parents b809b02 + cad33bf commit 9543c12
Showing 7 changed files with 177 additions and 545 deletions.
49 changes: 42 additions & 7 deletions medusa/bci/cvep_spellers.py
@@ -425,7 +425,7 @@ def check_predict_feasibility_signal(self, times, onsets, fs):
         return self.get_inst('clf_method')._is_predict_feasible_signal(
             times, onsets, fs)
 
-    def fit_dataset(self, dataset, **kwargs):
+    def fit_dataset(self, dataset, roll_targets=False, **kwargs):
         # Safe copy
         data = copy.deepcopy(dataset)
 
@@ -439,7 +439,8 @@ def fit_dataset(self, dataset, **kwargs):
         fitted_info = self.get_inst('clf_method').fit_dataset(
             dataset=data,
             std_epoch_rejection=3.0,
-            show_progress_bar=True
+            show_progress_bar=True,
+            roll_targets=roll_targets
         )
 
         return fitted_info
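
For reference, a minimal sketch of a call using the new flag; `speller` and `dataset` are illustrative names (not from this commit) for a pipeline object exposing the fit_dataset method above and a previously loaded CVEPSpellerDataset:

    # Illustrative only: fit the decoding model while reverting
    # each command's lag
    fitted_info = speller.fit_dataset(dataset, roll_targets=True)
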
@@ -875,7 +876,7 @@ def _assert_consistency(self, dataset: CVEPSpellerDataset):
         return fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank
 
     def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
-                    show_progress_bar=True):
+                    roll_targets=False, show_progress_bar=True):
 
         # Error checking
         fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank = \
@@ -903,6 +904,15 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
 
             # Get unique sequences for this run
             unique_seqs = unique_seqs_by_run[rec_idx]
+            if roll_targets:
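+                # Re-key each entry: replace the stored sequence by the
+                # command's base sequence rolled by its lag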
+                new_unique_seqs = dict()
+                for key, value in unique_seqs.items():
+                    c_ = rec_exp.command_idx[value[0]]
+                    c_lag_ = rec_exp.commands_info[0][str(int(c_))]['lag']
+                    c_seq_ = rec_exp.commands_info[0][str(int(c_))]['sequence']
+                    new_key = np.roll(c_seq_, c_lag_)
+                    new_unique_seqs[tuple(new_key)] = value
+                unique_seqs = new_unique_seqs
 
             # For each filter bank
             for filter_idx, signal in enumerate(rec_sig.signal):
@@ -915,6 +925,16 @@ def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
                     w_epoch_t=[0, len_epoch_ms],
                     w_baseline_t=None,
                     norm=None)
+
+                # Roll targets if training was not made with the 0 lag command
+                if roll_targets:
+                    for idx_, c_ in enumerate(rec_exp.command_idx):
+                        # TODO: nested matrices
+                        c_lag_ = rec_exp.commands_info[0][str(int(c_))]['lag']
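+                        # c_lag_ is given in frames of the stimulation
+                        # sequence (at fps_resolution); convert it to
+                        # samples of the recorded signal (at fs)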
+                        lag_samples = int(np.round(c_lag_ / fps_resolution * fs))
+                        # Revert the lag in the epoch
+                        epochs[idx_, :, :] = np.roll(
+                            epochs[idx_, :, :], lag_samples, axis=0)
 
                 # Organize epochs by sequence
                 for seq_, ep_idxs_ in unique_seqs.items():
@@ -1307,6 +1327,14 @@ def get_unique_sequences_from_targets(experiment: CVEPSpellerData):
     """ Function that returns the unique sequences of all targets as a
     dict that maps each unique sequence to the cycle indexes that used it.
     """
+    def is_shifted_version(stored_seqs, seq_to_check):
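+        # Return the stored sequence that is a circular shift of
+        # seq_to_check (checked against every rotation via np.roll),
+        # or None if no stored sequence matches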
+        for s in stored_seqs:
+            if len(s) != len(seq_to_check):
+                continue
+            for j in range(len(seq_to_check)):
+                if np.all(np.array(s) == np.roll(seq_to_check, -j)):
+                    return s
+        return None
     sequences = dict()
     try:
         # todo: command_idx, unit_idx and the like should be computed
         # by medusa, not by unity
@@ -1324,12 +1352,19 @@ def get_unique_sequences_from_targets(experiment: CVEPSpellerData):
             # interprets all dictionary keys as strings.
 
             # Add the command index to its associated sequence
-            if tuple(curr_seq_) not in sequences:
+            if len(sequences) == 0:
                 sequences[tuple(curr_seq_)] = [idx]
-            else:
+            elif tuple(curr_seq_) in sequences:
+                # Already there, add the cycle idx
                 sequences[tuple(curr_seq_)].append(idx)
 
-            # todo: check that sequences are not shifted versions of themselves??
+            else:
+                # If not there, first check that it is not a shifted
+                # version of a present sequence
+                orig_seq = is_shifted_version(list(sequences.keys()), curr_seq_)
+                if orig_seq is not None:
+                    sequences[tuple(orig_seq)].append(idx)
+                else:
+                    sequences[tuple(curr_seq_)] = [idx]
     except Exception as e:
         print(e)
     return sequences
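
For illustration, a minimal, self-contained sketch of the circular-shift check used above to group shifted sequences under one key (the sequences below are made up):

    import numpy as np

    stored = [(1, 0, 0, 1), (0, 1, 1, 0, 1)]
    candidate = (0, 1, 1, 0)  # circular shift of the first stored sequence

    match = None
    for s in stored:
        if len(s) != len(candidate):
            continue  # different lengths cannot be circular shifts
        for j in range(len(candidate)):
            # compare s against the candidate rolled left by j positions
            if np.all(np.array(s) == np.roll(candidate, -j)):
                match = s
                break
        if match is not None:
            break
    print(match)  # -> (1, 0, 0, 1): both would share one dict key
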
(The diffs of the remaining 6 changed files were not loaded in this view.)
