Skip to content

Commit

Permalink
prints for identifying mem issues
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisK92 committed Nov 29, 2024
1 parent 1ec08e6 commit 2f6c1ad
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
48 changes: 45 additions & 3 deletions spapros/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from spapros.util.mp_util import Signal, SigQueue, _get_n_cores, parallelize
from spapros.util.util import NestedProgress, init_progress


from spapros.util.util import print_memory_usage, get_size
# helper for type checking:


Expand Down Expand Up @@ -1169,6 +1171,7 @@ def plot_marker_corr(self, **kwargs):
Any keyword argument from :func:`.marker_correlation`.
Example:
.. code-block:: python
Expand Down Expand Up @@ -1605,6 +1608,10 @@ def train_ct_tree_helper(
TODO: Write docstring
"""

pid = os.getpid()
print_memory_usage(f"Process {pid}: Starting training cell types")

ct_trees = {}
for ct in celltypes:
ct_trees[ct] = tree.DecisionTreeClassifier(
Expand All @@ -1615,9 +1622,13 @@ def train_ct_tree_helper(
elif np.sum(masks[ct]) > 0:
ct_trees[ct] = ct_trees[ct].fit(X_train[masks[ct], :], y_train[ct][masks[ct]])

print_memory_usage(f"Process {pid}: After training cell type {ct}")

if queue is not None:
queue.put(Signal.UPDATE)

print_memory_usage(f"Process {pid}: After training all cell types")

if queue is not None:
queue.put(Signal.FINISH)

Expand Down Expand Up @@ -1816,6 +1827,8 @@ def single_forest_classifications(
# date)
# TODO: Add progress bars to trees, and maybe change verbose to verbosity levels

print_memory_usage("Starting single_forest_classifications")

# if verbose:
# try:
# from tqdm.notebook import tqdm
Expand Down Expand Up @@ -1892,6 +1905,11 @@ def single_forest_classifications(
ct_trees: Dict[str, list] = {ct: [] for ct in celltypes}
np.random.seed(seed=seed)
seeds = np.random.choice(100000, n_trees, replace=False)


print_memory_usage("Before training trees")
tree_counter = 0

# Compute trees (for each tree index we parallelize over celltypes)
# for i in tqdm(range(n_trees), desc="Train trees") if tqdm else range(n_trees):
if progress and verbose:
Expand All @@ -1900,6 +1918,11 @@ def single_forest_classifications(
X_train, y_train, cts_train = uniform_samples(
a, ct_key, set_key="train_set", subsample=subsample, seed=seeds[i], celltypes=ref_celltypes
)

if verbose:
tree_counter += 1
print(f"Training tree iteration: {tree_counter}/{n_trees}")

if ct_spec_ref is not None:
masks: Optional[Dict[str, np.ndarray[Any, np.dtype[np.bool_]]]] = get_reference_masks(
cts_train, ct_spec_ref
Expand All @@ -1923,23 +1946,40 @@ def single_forest_classifications(
del y_train
del cts_train
gc.collect()

print_memory_usage(f"After gc.collect() iteration {i}")

print_memory_usage("After training all trees")

print_memory_usage("Before calculating feature importances")

# Get feature importances
importances = {
ct: pd.DataFrame(index=a.var.index, columns=[str(i) for i in range(n_trees)], dtype="float64")
for ct in celltypes
}
for i in range(n_trees):
if verbose:
print(f"Calculating feature importances for tree {i+1}/{n_trees}")

for ct in celltypes:
if (masks_test is None) or (np.sum(masks_test[ct]) > 0):
importances[ct][str(i)] = ct_trees[ct][i].feature_importances_

# Garbage collection after each tree iteration
gc.collect()
print_memory_usage(f"After feature importance calculation iteration {i}")

print_memory_usage("Before evaluating trees")

# Evaluate trees (we parallelize over tree indices)
summary_metric, ct_specific_metric = parallelize(
callback=eval_ct_tree_helper,
collection=[i for i in range(n_trees)],
n_jobs=n_jobs,
backend=backend,
extractor=pool_eval_ct_tree_helper,
show_progress_bar=False, # =verbose,
show_progress_bar=False,
desc="Evaluate trees",
)(
celltypes=celltypes,
Expand All @@ -1950,11 +1990,13 @@ def single_forest_classifications(
cts_test=cts_test,
masks=masks_test,
)
# garbage collection

# Aggressive garbage collection after evaluation
del X_test
del y_test
del cts_test
gc.collect()
print_memory_usage("After tree evaluation and garbage collection")

# Sort results
if sort_by_tree_performance:
Expand Down Expand Up @@ -2161,7 +2203,7 @@ def forest_classifications(
adata:
An already preprocessed annotated data matrix. Typically we use log normalised data.
selection:
Trees are trained on genes of the list or genes defined in the bool column ``selection[selection]``.
Trees are trained on genes of the list or genes defined in the bool column ``selection['selection']``.
max_n_forests:
Number of best trees considered as a tree group. Including the primary tree.
verbosity:
Expand Down
14 changes: 12 additions & 2 deletions spapros/selection/selection_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
filter_marker_dict_by_shared_genes,
)


from spapros.util.util import print_memory_usage, get_size
# from tqdm.autonotebook import tqdm


Expand Down Expand Up @@ -509,7 +511,7 @@ def select_probeset(self) -> None:
"""
assert isinstance(self.progress, RichCast)
with self.progress:

print_memory_usage("Starting probeset selection")
if self.verbosity > 0:
selection_task = self.progress.add_task(
description="SPAPROS PROBESET SELECTION:", only_text=True, header=True, total=0
Expand All @@ -519,42 +521,50 @@ def select_probeset(self) -> None:
t = time.time()
if self.n_pca_genes and (self.n_pca_genes > 0):
self._pca_selection()
print_memory_usage("After PCA selection")
self._save_time_measurement("PCA_selection", t)

# DE forests
# DE forests
t = time.time()
self._forest_DE_baseline_selection()
print_memory_usage("After DE forest baseline selection")
self._save_time_measurement("DE_forest_selection", t)

# PCA forests (including optimization based on DE forests), or just DE forests if no PCA genes were selected
t = time.time()
if self.n_pca_genes and (self.n_pca_genes > 0):
self._forest_selection()
print_memory_usage("After PCA forest selection")
else:
self._set_DE_baseline_forest_to_final_forest()
print_memory_usage("After setting DE baseline forest as final")
self._save_time_measurement("PCA_forest_selection", t)

# Add markers from curated list
t = time.time()
if self.marker_list:
self._marker_selection()
print_memory_usage("After marker selection")
self._save_time_measurement("marker_selection", t)

# Compile probe set
t = time.time()
self.probeset = self._compile_probeset_list()
print_memory_usage("After compiling probeset")
self._save_time_measurement("compile_probeset", t)
self.selection["final"] = self.probeset

# Save attribute genes_of_primary_trees
self.genes_of_primary_trees = self._get_genes_of_primary_trees()
print_memory_usage("After getting primary tree genes")

if self.verbosity > 0:
self.progress.advance(selection_task)
self.progress.add_task(description="FINISHED\n", footer=True, only_text=True, total=0)

if self.save_dir and (not os.path.exists(self.probeset_path)):
self.probeset.to_csv(self.probeset_path)
print_memory_usage("After saving probeset")

def _pca_selection(self) -> None:
"""Select genes based on pca loadings."""
Expand Down
35 changes: 35 additions & 0 deletions spapros/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,41 @@
# Data Utils #
##############

import psutil
import os

def print_memory_usage(message="", log_file="memory_usage.log"):
    """Append the current process's resident memory usage (RSS) to a log file.

    NOTE(review): despite the name, console printing is disabled -- this
    function only appends to ``log_file``.

    Args:
        message: Optional message to prepend to the memory usage info.
        log_file: Path to the log file where memory usage will be recorded.
    """
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / (1024 * 1024)  # convert bytes to MB
    output = f"{message} Memory usage: {mem:.2f} MB"

    # Append one comma-separated entry: [timestamp],pid,mem_mb,message
    with open(log_file, "a") as f:
        timestamp = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
        f.write(f"[{timestamp}],{process.pid},{mem:.2f},{output}\n")

import sys

def get_size(obj, unit='MB'):
    """Return the shallow memory footprint of ``obj``.

    NOTE: ``sys.getsizeof`` is shallow -- it does not follow references into
    contained objects, so e.g. a list of large arrays reports only the list
    container itself.

    Args:
        obj: Any Python object.
        unit: One of ``'GB'``, ``'MB'`` or ``'KB'``; any other value returns
            the raw size in bytes (backward-compatible fallthrough).

    Returns:
        Size of ``obj`` in the requested unit (float for GB/MB/KB, int for
        raw bytes).
    """
    size_bytes = sys.getsizeof(obj)
    if unit == 'GB':
        return size_bytes / (1024 ** 3)
    if unit == 'MB':
        return size_bytes / (1024 * 1024)
    if unit == 'KB':
        return size_bytes / 1024
    return size_bytes




def get_processed_pbmc_data(n_hvg: int = 1000):
"""Get log normalised pbmc AnnData with selection and evaluation relevant quantities
Expand Down

0 comments on commit 2f6c1ad

Please sign in to comment.