From 4652a8095586715a290bf0eadbcd59d97194adcb Mon Sep 17 00:00:00 2001 From: Koncopd Date: Mon, 11 Dec 2023 09:51:04 +0100 Subject: [PATCH] add map_location to scpoli load_query_data --- scarches/models/scpoli/scpoli_model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scarches/models/scpoli/scpoli_model.py b/scarches/models/scpoli/scpoli_model.py index b5bab562..9d5a577f 100644 --- a/scarches/models/scpoli/scpoli_model.py +++ b/scarches/models/scpoli/scpoli_model.py @@ -714,7 +714,8 @@ def load_query_data( freeze: bool = True, freeze_expression: bool = True, remove_dropout: bool = True, - return_new_conditions: bool = False + return_new_conditions: bool = False, + map_location = None, ): """Transfer Learning function for new data. Uses old trained model and expands it for new conditions. @@ -724,20 +725,26 @@ def load_query_data( Query anndata object. reference_model SCPOLI model to expand or a path to SCPOLI model folder. + labeled_indices: List + List of integers with the indices of the labeled cells. + unknown_ct_names: List + List of strings with the names of cell clusters to be ignored for prototypes computation. freeze: Boolean If 'True' freezes every part of the network except the first layers of encoder/decoder. freeze_expression: Boolean If 'True' freeze every weight in first layers except the condition weights. remove_dropout: Boolean If 'True' remove Dropout for Transfer Learning. - + map_location + map_location to remap storage locations (as in '.load') of 'reference_model'. + Only taken into account if 'reference_model' is a path to a model on disk. Returns ------- new_model: scPoli New SCPOLI model to train on query data. """ if isinstance(reference_model, str): - attr_dict, model_state_dict, var_names = cls._load_params(reference_model) + attr_dict, model_state_dict, var_names = cls._load_params(reference_model, map_location) adata = _validate_var_names(adata, var_names) else: attr_dict = deepcopy(reference_model._get_public_attributes())