diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index ab705ee8..3ac7cc1c 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -350,7 +350,9 @@ def forward(self, x_data, idx, batch_index): # cell state signatures (e.g. background, free-floating RNA) s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", - dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta), + dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta) + .expand([1, 1]) + .to_event(2), ) s_g_gene_add_mean = pyro.sample( "s_g_gene_add_mean", @@ -377,7 +379,7 @@ def forward(self, x_data, idx, batch_index): # =====================Gene-specific overdispersion ======================= # alpha_g_phi_hyp = pyro.sample( "alpha_g_phi_hyp", - dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta), + dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta).expand([1, 1]).to_event(2), ) alpha_g_inverse = pyro.sample( "alpha_g_inverse", @@ -389,10 +391,6 @@ def forward(self, x_data, idx, batch_index): # expected expression mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) - # convert mean and overdispersion to total count and logits - # total_count, logits = _convert_mean_disp_to_counts_logits( - # mu, alpha, eps=self.eps - # ) # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial @@ -400,7 +398,6 @@ def forward(self, x_data, idx, batch_index): pyro.sample( "data_target", dist.GammaPoisson(concentration=alpha, rate=alpha / mu), - # dist.NegativeBinomial(total_count=total_count, logits=logits), obs=x_data, ) diff --git a/cell2location/models/base/_pyro_mixin.py b/cell2location/models/base/_pyro_mixin.py index 39e1731c..da44c9cd 100755 --- a/cell2location/models/base/_pyro_mixin.py +++ b/cell2location/models/base/_pyro_mixin.py @@ -1,23 +1,100 @@ from datetime import date from functools import partial +from typing import Callable, Tuple, Union import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd import pyro +import pyro.distributions as dist import torch from pyro import poutine -from pyro.infer.autoguide import AutoNormal, init_to_mean +from pyro.distributions.distribution import Distribution +from pyro.infer.autoguide import AutoNormal +from pyro.infer.autoguide import AutoNormalMessenger as AutoNormalMessengerPyro +from pyro.infer.autoguide import init_to_feasible, init_to_mean +from pyro.infer.autoguide.utils import helpful_support_errors from scipy.sparse import issparse from scvi import _CONSTANTS from scvi.data._anndata import get_from_registry from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_use_gpu_arg +from torch.distributions import biject_to from ...distributions.AutoNormalEncoder import AutoGuideList, AutoNormalEncoder +class AutoNormalMessenger(AutoNormalMessengerPyro): + """ + :class:`AutoMessenger` with mean-field normal posterior. + + Copied from Pyro with modifications adding quantile methods. + + The mean-field posterior at any site is a transformed normal distribution. + This posterior is equivalent to :class:`~pyro.infer.autoguide.AutoNormal` + or :class:`~pyro.infer.autoguide.AutoDiagonalNormal`, but allows + customization via subclassing. + + :param callable model: A Pyro model. + :param callable init_loc_fn: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. + :param float init_scale: Initial scale for the standard deviation of each + (unconstrained transformed) latent variable. + :param tuple amortized_plates: A tuple of names of plates over which guide + parameters should be shared. This is useful for subsampling, where a + guide parameter can be shared across all plates. + """ + + def __init__( + self, + model: Callable, + *, + init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), + init_scale: float = 0.1, + amortized_plates: Tuple[str, ...] = (), + ): + if not isinstance(init_scale, float) or not (init_scale > 0): + raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) + super().__init__(model, amortized_plates=amortized_plates) + self.init_loc_fn = init_loc_fn + self._init_scale = init_scale + self._computing_median = False + self._computing_quantiles = False + self._quantile_values = None + + def get_posterior(self, name: str, prior: Distribution) -> Union[Distribution, torch.Tensor]: + if self._computing_median: + return self._get_posterior_median(name, prior) + if self._computing_quantiles: + return self._get_posterior_quantiles(name, prior) + + with helpful_support_errors({"name": name, "fn": prior}): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + posterior = dist.TransformedDistribution( + dist.Normal(loc, scale).to_event(transform.domain.event_dim), + transform.with_cache(), + ) + return posterior + + def quantiles(self, quantiles, *args, **kwargs): + self._computing_quantiles = True + self._quantile_values = quantiles + try: + return self(*args, **kwargs) + finally: + self._computing_quantiles = False + + @torch.no_grad() + def _get_posterior_quantiles(self, name, prior): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + site_quantiles = torch.tensor(self._quantile_values, dtype=loc.dtype, device=loc.device) + site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) + return transform(site_quantiles_values) + + def init_to_value(site=None, values={}): if site is None: return partial(init_to_value, values=values) @@ -50,10 +127,10 @@ def _create_autoguide( ): if not amortised: - _guide = AutoNormal( + _guide = AutoNormalMessenger( model, init_loc_fn=init_loc_fn, - create_plates=model.create_plates, + # create_plates=model.create_plates, ) else: encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict() diff --git a/cell2location/models/reference/_reference_module.py b/cell2location/models/reference/_reference_module.py index 0db4a875..d7e1a58e 100755 --- a/cell2location/models/reference/_reference_module.py +++ b/cell2location/models/reference/_reference_module.py @@ -232,7 +232,9 @@ def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): # s_{e,g} accounting for background, free-floating RNA s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", - dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta), + dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta) + .expand([1, 1]) + .to_event(2), ) s_g_gene_add_mean = pyro.sample( "s_g_gene_add_mean", @@ -259,7 +261,7 @@ def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): # =====================Gene-specific overdispersion ======================= # alpha_g_phi_hyp = pyro.sample( "alpha_g_phi_hyp", - dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta), + dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta).expand([1, 1]).to_event(2), ) alpha_g_inverse = pyro.sample( "alpha_g_inverse", @@ -277,9 +279,6 @@ def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): if self.n_extra_categoricals is not None: # gene-specific normalisation for covatiates mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg) - # total_count, logits = _convert_mean_disp_to_counts_logits( - # mu, alpha, eps=self.eps - # ) # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial @@ -287,7 +286,6 @@ def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): pyro.sample( "data_target", dist.GammaPoisson(concentration=alpha, rate=alpha / mu), - # dist.NegativeBinomial(total_count=total_count, logits=logits), obs=x_data, ) diff --git a/cell2location/models/simplified/_cell2location_v3_no_factorisation_module.py b/cell2location/models/simplified/_cell2location_v3_no_factorisation_module.py index 07fc1995..b7176b66 100755 --- a/cell2location/models/simplified/_cell2location_v3_no_factorisation_module.py +++ b/cell2location/models/simplified/_cell2location_v3_no_factorisation_module.py @@ -250,12 +250,14 @@ def forward(self, x_data, idx, batch_index): dist.Gamma( self.N_cells_per_location * self.N_cells_mean_var_ratio, self.N_cells_mean_var_ratio, - ), + ) + .expand([1, 1]) + .to_event(2), ) a_factors_per_location = pyro.sample( "a_factors_per_location", - dist.Gamma(self.A_factors_per_location, self.ones), + dist.Gamma(self.A_factors_per_location, self.ones).expand([1, 1]).to_event(2), ) # cell group loadings @@ -299,7 +301,9 @@ def forward(self, x_data, idx, batch_index): # cell state signatures (e.g. background, free-floating RNA) s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", - dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta), + dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta) + .expand([1, 1]) + .to_event(2), ) s_g_gene_add_mean = pyro.sample( "s_g_gene_add_mean", @@ -326,7 +330,7 @@ def forward(self, x_data, idx, batch_index): # =====================Gene-specific overdispersion ======================= # alpha_g_phi_hyp = pyro.sample( "alpha_g_phi_hyp", - dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta), + dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta).expand([1, 1]).to_event(2), ) alpha_g_inverse = pyro.sample( "alpha_g_inverse", @@ -337,10 +341,6 @@ def forward(self, x_data, idx, batch_index): # expected expression mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) - # convert mean and overdispersion to total count and logits - # total_count, logits = _convert_mean_disp_to_counts_logits( - # mu, alpha, eps=self.eps - # ) # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial @@ -348,7 +348,6 @@ def forward(self, x_data, idx, batch_index): pyro.sample( "data_target", dist.GammaPoisson(concentration=alpha, rate=alpha / mu), - # dist.NegativeBinomial(total_count=total_count, logits=logits), obs=x_data, ) diff --git a/cell2location/models/simplified/_cell2location_v3_no_mg_module.py b/cell2location/models/simplified/_cell2location_v3_no_mg_module.py index 6152c5a7..cf32d261 100755 --- a/cell2location/models/simplified/_cell2location_v3_no_mg_module.py +++ b/cell2location/models/simplified/_cell2location_v3_no_mg_module.py @@ -288,7 +288,9 @@ def forward(self, x_data, idx, batch_index): # cell state signatures (e.g. background, free-floating RNA) s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", - dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta), + dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta) + .expand([1, 1]) + .to_event(2), ) s_g_gene_add_mean = pyro.sample( "s_g_gene_add_mean", @@ -315,7 +317,7 @@ def forward(self, x_data, idx, batch_index): # =====================Gene-specific overdispersion ======================= # alpha_g_phi_hyp = pyro.sample( "alpha_g_phi_hyp", - dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta), + dist.Gamma(self.alpha_g_phi_hyp_prior_alpha, self.alpha_g_phi_hyp_prior_beta).expand([1, 1]).to_event(2), ) alpha_g_inverse = pyro.sample( "alpha_g_inverse", @@ -326,10 +328,6 @@ def forward(self, x_data, idx, batch_index): # expected expression mu = ((w_sf @ self.cell_state) + (obs2sample @ s_g_gene_add)) * detection_y_s alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) - # convert mean and overdispersion to total count and logits - # total_count, logits = _convert_mean_disp_to_counts_logits( - # mu, alpha, eps=self.eps - # ) # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial @@ -337,7 +335,6 @@ def forward(self, x_data, idx, batch_index): pyro.sample( "data_target", dist.GammaPoisson(concentration=alpha, rate=alpha / mu), - # dist.NegativeBinomial(total_count=total_count, logits=logits), obs=x_data, )