Skip to content

Commit

Permalink
add map_location to scpoli load_query_data
Browse files Browse the repository at this point in the history
  • Loading branch information
Koncopd committed Dec 11, 2023
1 parent 6a3bacc commit 4652a80
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions scarches/models/scpoli/scpoli_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,8 @@ def load_query_data(
freeze: bool = True,
freeze_expression: bool = True,
remove_dropout: bool = True,
return_new_conditions: bool = False
return_new_conditions: bool = False,
map_location = None,
):
"""Transfer Learning function for new data. Uses old trained model and expands it for new conditions.
Expand All @@ -724,20 +725,26 @@ def load_query_data(
Query anndata object.
reference_model
SCPOLI model to expand or a path to SCPOLI model folder.
labeled_indices: List
List of integers with the indices of the labeled cells.
unknown_ct_names: List
List of strings with the names of cell clusters to be ignored for prototypes computation.
freeze: Boolean
If 'True' freezes every part of the network except the first layers of encoder/decoder.
freeze_expression: Boolean
If 'True' freeze every weight in first layers except the condition weights.
remove_dropout: Boolean
If 'True' remove Dropout for Transfer Learning.
map_location
map_location to remap storage locations (as in '.load') of 'reference_model'.
Only taken into account if 'reference_model' is a path to a model on disk.
Returns
-------
new_model: scPoli
New SCPOLI model to train on query data.
"""
if isinstance(reference_model, str):
attr_dict, model_state_dict, var_names = cls._load_params(reference_model)
attr_dict, model_state_dict, var_names = cls._load_params(reference_model, map_location)
adata = _validate_var_names(adata, var_names)
else:
attr_dict = deepcopy(reference_model._get_public_attributes())
Expand Down

0 comments on commit 4652a80

Please sign in to comment.