
Commit

reformating: v4
janursa committed Feb 2, 2025
1 parent 7cd9b07 commit 0fc738e
Showing 6 changed files with 357 additions and 351 deletions.
1 change: 1 addition & 0 deletions src/methods/single_omics/grnboost2/script.py
@@ -8,6 +8,7 @@
from tqdm import tqdm
import subprocess
import sys
import anndata as ad

## VIASH START
par = {
1 change: 1 addition & 0 deletions src/methods/single_omics/scenic/script.py
@@ -7,6 +7,7 @@
import requests
import scipy.sparse as sp
import sys
import anndata as ad



4 changes: 4 additions & 0 deletions src/methods/single_omics/scgpt/run.sh
@@ -0,0 +1,4 @@
viash run src/methods/single_omics/scgpt/config.vsh.yaml -- \
--rna resources_test/inference_datasets/op_rna.h5ad \
--tf_all resources/prior/tf_all.csv \
--prediction output/prediction.h5ad
165 changes: 34 additions & 131 deletions src/methods/single_omics/scgpt/script.py
@@ -170,146 +170,49 @@ def monitor_memory():
import scanpy as sc
adata = sc.read(par['rna'])
adata.X = adata.X.todense()
adata.obs["celltype"] = adata.obs["cell_type"].astype("category")
adata.obs["celltype"] = adata.obs["cell_type"].astype("str")
adata.obs["str_batch"] = adata.obs["donor_id"].astype(str)
data_is_raw = True

adata.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in adata.var.index]
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
adata = adata[:, adata.var["id_in_vocab"] >= 0]
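# Optional diagnostic sketch (illustrative, names taken from the lines above):
# report how many genes matched the scGPT vocabulary before subsetting.
print(f"{np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes found in the scGPT vocab")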

# preprocessor = Preprocessor(
# use_key="counts", # the key in adata.layers to use as raw data
# filter_gene_by_counts=3, # step 1
# filter_cell_by_counts=False, # step 2
# normalize_total=1e4, # 3. whether to normalize the raw data and to what sum
# result_normed_key="X_normed", # the key in adata.layers to store the normalized data
# log1p=data_is_raw, # 4. whether to log1p the normalized data
# result_log1p_key="X_log1p",
# subset_hvg= False, # 5. whether to subset the raw data to highly variable genes
# hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
# binning=n_input_bins, # 6. whether to bin the raw data and to what number of bins
# result_binned_key="X_binned", # the key in adata.layers to store the binned data
# )
# preprocessor(adata, batch_key="str_batch")

# print('Subsetting to HVGs')
# sc.pp.highly_variable_genes(
# adata,
# layer=None,
# n_top_genes=n_hvg,
# batch_key="str_batch",
# flavor="seurat_v3" if data_is_raw else "cell_ranger",
# subset=False,
# )
# adata = adata[:, adata.var["highly_variable"]].copy()



input_layer_key = "X_norm"
all_counts = (
adata.layers[input_layer_key].A
if issparse(adata.layers[input_layer_key])
else adata.layers[input_layer_key]
)
genes = adata.var.index.tolist()
gene_ids = np.array(vocab(genes), dtype=int)

tokenized_all = tokenize_and_pad_batch(
all_counts,
gene_ids,
max_len=len(genes)+1,
vocab=vocab,
pad_token=pad_token,
pad_value=pad_value,
append_cls=True, # append <cls> token at the beginning
include_zero_gene=True,
preprocessor = Preprocessor(
use_key="X", # the key in adata.layers to use as raw data
filter_gene_by_counts=3, # step 1
filter_cell_by_counts=False, # step 2
normalize_total=1e4, # 3. whether to normalize the raw data and to what sum
result_normed_key="X_normed", # the key in adata.layers to store the normalized data
log1p=data_is_raw, # 4. whether to log1p the normalized data
result_log1p_key="X_log1p",
subset_hvg=False, # 5. whether to subset the raw data to highly variable genes
hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
binning=51, # 6. whether to bin the raw data and to what number of bins
result_binned_key="X_binned", # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key="str_batch") # batch column created above
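# Illustrative check (assumes scGPT's Preprocessor stores results under the
# keys configured above): the processed layers should now be present.
for key in ("X_normed", "X_log1p", "X_binned"):
    assert key in adata.layers, f"missing layer {key}"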

# Retrieve the data-independent gene embeddings from scGPT
gene_ids = np.array(list(gene2idx.values()))
gene_embeddings = model.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
gene_embeddings = gene_embeddings.detach().cpu().numpy()
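# Minimal sketch (illustrative): the static embeddings can be compared
# directly, e.g. cosine similarity between the first two vocabulary entries.
v0, v1 = gene_embeddings[0], gene_embeddings[1]
cos_sim = v0 @ v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))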

all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
# Keep only embeddings for genes present in this dataset (intersection of adata's genes and scGPT's 30K+ token vocabulary)
gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys()) if gene in adata.var.index.tolist()}
print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))


src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
embed = GeneEmbedding(gene_embeddings)

condition_ids = np.array(adata.obs[par['condition']].tolist())

torch.cuda.empty_cache()
dict_sum_condition = {}
print('Extract gene-gene links from attention layers')
model.eval()
monitor_memory()
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
M = all_gene_ids.size(1)
N = all_gene_ids.size(0)
device = next(model.parameters()).device
for i in tqdm(range(0, N, batch_size)):
batch_size = all_gene_ids[i : i + batch_size].size(0)
outputs = np.zeros((batch_size, M, M), dtype=np.float32)
# Replicate the operations in model forward pass
src_embs = model.encoder(torch.tensor(all_gene_ids[i : i + batch_size], dtype=torch.long).to(device))
# monitor_memory()
val_embs = model.value_encoder(torch.tensor(all_values[i : i + batch_size], dtype=torch.float).to(device))
total_embs = src_embs + val_embs
total_embs = model.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
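# Shape note (assuming scGPT's TransformerModel layout): total_embs is
# [batch, seq_len, embsize]; BatchNorm1d normalizes over dim 1, hence the
# permute to [batch, embsize, seq_len] and back.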
# Send total_embs to attention layers for attention operations
# Retrieve the output from the second-to-last layer
for layer in model.transformer_encoder.layers[:num_attn_layers]:
total_embs = layer(total_embs, src_key_padding_mask=src_key_padding_mask[i : i + batch_size].to(device))
# Send total_embs to the last layer in flash-attn
# https://github.com/HazyResearch/flash-attention/blob/1b18f1b7a133c20904c096b8b222a0916e1b3d37/flash_attn/flash_attention.py#L90
qkv = model.transformer_encoder.layers[num_attn_layers].self_attn.Wqkv(total_embs)
# Retrieve q, k, and v from the flash-attn wrapper
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=8)
q = qkv[:, :, 0, :, :]
k = qkv[:, :, 1, :, :]
v = qkv[:, :, 2, :, :]
# https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
# q = [batch, gene, n_heads, n_hid]
# k = [batch, gene, n_heads, n_hid]
# attn_scores = [batch, n_heads, gene, gene]
attn_scores = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)
# Rank normalization by row
attn_scores = attn_scores.reshape((-1, M))
order = torch.argsort(attn_scores, dim=1)
rank = torch.argsort(order, dim=1)
attn_scores = rank.reshape((-1, 8, M, M))/M
# Rank normalization by column
attn_scores = attn_scores.permute(0, 1, 3, 2).reshape((-1, M))
order = torch.argsort(attn_scores, dim=1)
rank = torch.argsort(order, dim=1)
attn_scores = (rank.reshape((-1, 8, M, M))/M).permute(0, 1, 3, 2)
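# Worked sketch of the rank normalization above (illustrative): for a single
# row of scores [0.2, 0.9, 0.5] with M = 3,
#   order = argsort(scores) -> [0, 2, 1]; rank = argsort(order) -> [0, 2, 1]
#   rank / M -> [0.00, 0.67, 0.33], i.e. raw scores become ranks in [0, 1).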
# Average 8 attention heads
attn_scores = attn_scores.mean(1)
outputs = attn_scores.detach().cpu().numpy()
for index in range(batch_size):
    # Running sum of attention maps per condition: initialize on first
    # encounter, then always accumulate so the first batch is not dropped.
    c = condition_ids[i : i + batch_size][index]
    if c not in dict_sum_condition:
        dict_sum_condition[c] = np.zeros((M, M), dtype=np.float32)
    dict_sum_condition[c] += outputs[index, :, :]
print('Average across groups of cell types')
groups = adata.obs.groupby([par['condition']]).groups
dict_sum_condition_mean = dict_sum_condition.copy()
for i in groups.keys():
dict_sum_condition_mean[i] = dict_sum_condition_mean[i]/len(groups[i])
mean_grn = np.array(list(dict_sum_condition_mean.values())).mean(axis=0)
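# Worked sketch (illustrative): with two conditions A and B, mean_grn is the
# element-wise mean of the per-condition averages, (mean_A + mean_B) / 2,
# still an (M, M) gene-by-gene matrix.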
print('Subset only for TFs')
gene_vocab_idx = all_gene_ids[0].clone().detach().cpu().numpy()
gene_names = vocab.lookup_tokens(gene_vocab_idx)

print('Format as df, melt, and subset')
net = pd.DataFrame(mean_grn, columns=gene_names, index=gene_names)
net = net.iloc[1:, 1:] # drop the first row/column, which corresponds to the <cls> token

tf_all = np.intersect1d(tf_all, gene_names)
net = net[tf_all]

net_melted = net.reset_index() # Move index to a column for melting
net_melted = pd.melt(net_melted, id_vars=net_melted.columns[0], var_name='target', value_name='weight')
net_melted.rename(columns={net_melted.columns[0]: 'source'}, inplace=True)
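# Illustrative check: net_melted is now in the long format used downstream,
# one row per (source, target) pair.
assert list(net_melted.columns) == ['source', 'target', 'weight']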
# Perform Louvain clustering with desired resolution; here we specify resolution=40
gdata = embed.get_adata(resolution=40)
# Retrieve the gene clusters
metagenes = embed.get_metagenes(gdata)

# Obtain the set of gene programs from clusters with #genes >= 5
mgs = dict()
for mg, genes in metagenes.items():
if len(genes) > 4:
mgs[mg] = genes
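# Equivalent one-liner (illustrative): keep only clusters with >= 5 genes.
# mgs = {mg: genes for mg, genes in metagenes.items() if len(genes) >= 5}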



net = net_melted
net['weight'] = net['weight'].astype(str)
