Skip to content

Commit

Permalink
Add cell cycle score baseline (#706)
Browse files Browse the repository at this point in the history
* add cc_score baseline

* document

* Make sure method didn't remove uns

* Combat tramples uns

* Revert

* Scale and hvg trample uns

* scanorama clears uns

* mnn tramples uns

* just copy uns

* just copy uns

* don't set X_emb if missing; it shouldn't ever be missing

* use true features as embedding

* compute PCA per batch

* Set code version
  • Loading branch information
scottgigante-immunai authored Dec 3, 2022
1 parent 7361925 commit 7ffc855
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Datasets should contain the following attributes:
* `adata.obsm['X_uni_pca']` with a pre-integration embedding (PCA)
* `adata.layers['log_normalized']` with log-normalized data
* `adata.X` with log-normalized data
* `adata.uns["organism"]` with either `"mouse"` or `"human"`

Methods should assign output to `adata.obsm['X_emb']`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ def check_dataset(adata):
assert "batch" in adata.obs
assert "labels" in adata.obs
assert "log_normalized" in adata.layers
assert "organism" in adata.uns
assert adata.uns["organism"] in ["mouse", "human"]

return True


def check_method(adata, is_baseline=False):
    """Validate the output of an integration method against the task API.

    Parameters
    ----------
    adata
        Object returned by a method; its ``obsm`` and ``uns`` mappings are
        inspected.
    is_baseline
        Accepted for API compatibility; not used by these checks.

    Returns
    -------
    bool
        ``True`` when every check passes.

    Raises
    ------
    AssertionError
        If ``"X_emb"`` is missing from ``adata.obsm`` or ``"organism"`` is
        missing from ``adata.uns``.
    """
    embedding_present = "X_emb" in adata.obsm
    assert embedding_present
    # the method must not have dropped the organism annotation from uns
    organism_kept = "organism" in adata.uns
    assert organism_kept
    return True


Expand All @@ -27,6 +31,7 @@ def sample_dataset():
import scanpy as sc

adata = load_sample_data()
adata.uns["organism"] = "human"

adata.var.index = adata.var.gene_short_name.astype(str)
sc.pp.normalize_total(adata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .baseline import celltype_random_embedding
from .baseline import celltype_random_integration
from .baseline import no_integration
from .baseline import no_integration_batch
from .baseline import random_integration
from .scalex import scalex_full
from .scalex import scalex_hvg
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from ...batch_integration_graph.methods.baseline import _random_embedding
from ...batch_integration_graph.methods.baseline import _randomize_features

import numpy as np
import scanpy as sc


@method(
method_name="No Integration",
Expand Down Expand Up @@ -76,3 +79,32 @@ def batch_random_integration(adata, test=False):
)
adata.uns["method_code_version"] = check_version("openproblems")
return adata


@method(
    method_name="No Integration by Batch",
    paper_name="No Integration by Batch (baseline)",
    paper_url="https://openproblems.bio",
    paper_year=2022,
    code_url="https://github.com/openproblems-bio/openproblems",
    is_baseline=True,
)
def no_integration_batch(adata, test=False):
    """Compute PCA independently on each batch.

    See https://github.com/theislab/scib/issues/351

    The per-batch principal components are written into the rows of
    ``adata.obsm["X_emb"]`` belonging to that batch. The embedding always has
    50 columns; when a batch yields fewer components, the remaining columns
    of its rows stay zero.

    Parameters
    ----------
    adata
        Dataset with a ``"batch"`` column in ``adata.obs``.
    test
        Not used by this baseline.

    Returns
    -------
    The same ``adata`` with ``obsm["X_emb"]`` and
    ``uns["method_code_version"]`` set.
    """
    n_dims = 50
    adata.obsm["X_emb"] = np.zeros((adata.shape[0], n_dims), dtype=float)
    for batch in adata.obs["batch"].unique():
        batch_idx = adata.obs["batch"] == batch
        # PCA cannot return more components than min(n_cells, n_genes), so
        # bound by both dimensions rather than by the cell count alone.
        max_comps = min(int(np.sum(batch_idx)), adata.shape[1])
        n_comps = min(n_dims, max_comps)
        # arpack needs n_comps strictly below min(n_cells, n_genes); fall
        # back to the full solver when we are exactly at that limit.
        solver = "full" if n_comps == max_comps else "arpack"
        adata.obsm["X_emb"][batch_idx, :n_comps] = sc.tl.pca(
            adata[batch_idx],
            n_comps=n_comps,
            use_highly_variable=False,
            svd_solver=solver,
            copy=True,
        ).obsm["X_pca"]
    adata.uns["method_code_version"] = check_version("openproblems")
    return adata
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
def _get_split(adata):
uni = adata
uni.obsm["X_pca"] = uni.obsm["X_uni_pca"]

if "X_emb" not in adata.obsm:
adata.obsm["X_emb"] = adata.obsm["X_pca"]

return (uni, adata)
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
@metric(
metric_name="Cell Cycle Score",
maximize=True,
image="openproblems-python-batch-integration", # only if required
image="openproblems-python-batch-integration",
)
def cc_score(adata, test=False):
from ._utils import _get_split
from scib.metrics import cell_cycle

try:
cc = cell_cycle(*_get_split(adata), "batch", embed="X_emb", organism="human")
cc = cell_cycle(
*_get_split(adata), "batch", embed="X_emb", organism=adata.uns["organism"]
)

except ValueError:
cc = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def immune_batch(test=False):
import scanpy as sc

adata = load_immune(test)
adata.uns["organism"] = "human"
adata.obs["labels"] = adata.obs["final_annotation"]

sc.pp.filter_genes(adata, min_counts=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def pancreas_batch(test=False):
import scanpy as sc

adata = load_pancreas(test)
adata.uns["organism"] = "human"
adata.obs["labels"] = adata.obs["celltype"]
adata.obs["batch"] = adata.obs["tech"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ def hvg_batch(adata, batch_key, target_genes, adataOut):
if adata.n_vars < 2000:
return adata
else:
# uns and var get trampled
uns = adata.uns.copy()
var = adata.var.copy()
adata = hvg_batch(
adata,
Expand All @@ -13,13 +15,17 @@ def hvg_batch(adata, batch_key, target_genes, adataOut):
adataOut=adataOut,
)
adata.var = var.loc[adata.var.index]
adata.uns = uns
return adata


def scale_batch(adata, batch_key):
    """Run scib's ``scale_batch`` while keeping ``var`` and ``uns`` intact.

    Parameters
    ----------
    adata
        Dataset to scale.
    batch_key
        Passed through to ``scib.preprocessing.scale_batch``.

    Returns
    -------
    The object returned by scib, with ``var`` restored (restricted to the
    returned ``var_names``) and ``uns`` restored from copies taken before
    the call.
    """
    from scib.preprocessing import scale_batch as _scib_scale_batch

    # scib tramples var and uns, so snapshot both before calling it
    saved_var = adata.var.copy()
    saved_uns = adata.uns.copy()

    scaled = _scib_scale_batch(adata, batch_key)

    scaled.var = saved_var.loc[scaled.var_names]
    scaled.uns = saved_uns
    return scaled
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def _mnn(adata):
from scib.integration import runMNN
from scib.preprocessing import reduce_data

# mnn clears adata.uns
uns = adata.uns
adata = runMNN(adata, "batch")
adata.uns = uns
reduce_data(adata, umap=False)
adata.obsm["X_emb"] = adata.obsm["X_pca"]
adata.uns["method_code_version"] = check_version("mnnpy")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def _scanorama(adata, use_rep, pca):
from scib.integration import scanorama
from scib.preprocessing import reduce_data

# scanorama clears adata.layers
# scanorama clears adata.layers and uns
layers = adata.layers
uns = adata.uns
adata = scanorama(adata, "batch")
adata.layers = layers
adata.uns = uns
reduce_data(adata, umap=False, use_rep=use_rep, pca=pca)
adata.uns["method_code_version"] = check_version("scanorama")
return adata
Expand Down

0 comments on commit 7ffc855

Please sign in to comment.