From ffdccdd86849da531d011829c95dce9693f0e11c Mon Sep 17 00:00:00 2001 From: Amir Ali Moinfar Date: Fri, 27 Oct 2023 16:40:24 +0200 Subject: [PATCH 1/2] Improve train/test split performance when cell_type is available --- scarches/trainers/scpoli/_utils.py | 32 ++++++++++++++---------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/scarches/trainers/scpoli/_utils.py b/scarches/trainers/scpoli/_utils.py index 70a77910..c62a8027 100644 --- a/scarches/trainers/scpoli/_utils.py +++ b/scarches/trainers/scpoli/_utils.py @@ -127,28 +127,26 @@ def train_test_split(adata, train_frac=0.85, condition_keys=None, cell_type_key= labeled_idx = indices[labeled_array == 1] unlabeled_idx = indices[labeled_array == 0] - train_labeled_idx = [] - val_labeled_idx = [] - train_unlabeled_idx = [] - val_unlabeled_idx = [] + train_labeled_idx = np.array([], dtype=int) + val_labeled_idx = np.array([], dtype=int) + train_unlabeled_idx = np.array([], dtype=int) + val_unlabeled_idx = np.array([], dtype=int) if len(labeled_idx) > 0: - cell_types = adata[labeled_idx].obs[cell_type_key].unique().tolist() - for cell_type in cell_types: - ct_idx = labeled_idx[adata[labeled_idx].obs[cell_type_key] == cell_type] - n_train_samples = int(np.ceil(train_frac * len(ct_idx))) - np.random.shuffle(ct_idx) - train_labeled_idx.append(ct_idx[:n_train_samples]) - val_labeled_idx.append(ct_idx[n_train_samples:]) + cell_type_info = adata[labeled_idx].obs[[cell_type_key]].copy() + cell_type_info['random'] = np.random.rand(len(cell_type_info.index)) + cell_type_info['count_in_ct'] = cell_type_info.groupby(cell_type_key, observed=True)['random'].transform('count') + cell_type_info['rank_in_ct'] = cell_type_info.groupby(cell_type_key, observed=True)['random'].rank(method="first") - 1 + cell_type_info['train'] = cell_type_info['count_in_ct'] * train_frac > cell_type_info['rank_in_ct'] + train_labeled_idx = labeled_idx[cell_type_info['train']] + val_labeled_idx = 
labeled_idx[~cell_type_info['train']] if len(unlabeled_idx) > 0: n_train_samples = int(np.ceil(train_frac * len(unlabeled_idx))) - train_unlabeled_idx.append(unlabeled_idx[:n_train_samples]) - val_unlabeled_idx.append(unlabeled_idx[n_train_samples:]) - train_idx = train_labeled_idx + train_unlabeled_idx - val_idx = val_labeled_idx + val_unlabeled_idx + train_unlabeled_idx = unlabeled_idx[:n_train_samples] + val_unlabeled_idx = unlabeled_idx[n_train_samples:] - train_idx = np.concatenate(train_idx) - val_idx = np.concatenate(val_idx) + train_idx = np.concatenate([train_labeled_idx, train_unlabeled_idx]) + val_idx = np.concatenate([val_labeled_idx, val_unlabeled_idx]) elif condition_keys is not None: train_idx = [] From 871300332046182e9f27cda791fe7be050d63cfb Mon Sep 17 00:00:00 2001 From: Amir Ali Moinfar Date: Fri, 27 Oct 2023 16:40:59 +0200 Subject: [PATCH 2/2] Improve dataloader performance with batch indexing --- scarches/dataset/scpoli/anndata.py | 20 ++++++++++++-------- scarches/trainers/scpoli/_utils.py | 3 +++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/scarches/dataset/scpoli/anndata.py b/scarches/dataset/scpoli/anndata.py index e49010f6..96261b22 100644 --- a/scarches/dataset/scpoli/anndata.py +++ b/scarches/dataset/scpoli/anndata.py @@ -76,27 +76,31 @@ def __init__(self, self.cell_types = np.stack(self.cell_types).T self.cell_types = torch.tensor(self.cell_types, dtype=torch.long) - def __getitem__(self, index): + def __getitems__(self, indices): + # Make sure this function supports both single-element and list input outputs = dict() if self._is_sparse: - x = torch.tensor(np.squeeze(self.data[index].toarray()), dtype=torch.float32) + x = torch.tensor(np.squeeze(self.data[indices].toarray()), dtype=torch.float32) else: - x = self.data[index] + x = self.data[indices] outputs["x"] = x - outputs["labeled"] = self.labeled_vector[index] - outputs["sizefactor"] = self.size_factors[index] + outputs["labeled"] = self.labeled_vector[indices] + 
outputs["sizefactor"] = self.size_factors[indices] if self.condition_keys: outputs["batch"] = self.conditions[indices, :] outputs["combined_batch"] = self.conditions_combined[indices] if self.cell_type_keys: outputs["celltypes"] = self.cell_types[indices, :] return outputs + def __getitem__(self, index): + return self.__getitems__(index) + def __len__(self): return self.data.shape[0] diff --git a/scarches/trainers/scpoli/_utils.py b/scarches/trainers/scpoli/_utils.py index c62a8027..05e5816e 100644 --- a/scarches/trainers/scpoli/_utils.py +++ b/scarches/trainers/scpoli/_utils.py @@ -67,6 +67,9 @@ def _print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, leng sys.stdout.flush() def custom_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" + if isinstance(batch, dict): + return batch + np_str_obj_array_pattern = re.compile(r'[SaUO]') default_collate_err_msg_format = (