Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ you should jump to {ref}`array_stats_api` and read forward.

arviz_stats.compare
arviz_stats.loo
arviz_stats.loo_i
arviz_stats.loo_approximate_posterior
arviz_stats.loo_kfold
arviz_stats.loo_moment_match
Expand Down
1 change: 1 addition & 0 deletions src/arviz_stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from arviz_stats.accessors import *
from arviz_stats.loo import (
loo,
loo_i,
loo_expectations,
loo_metrics,
loo_pit,
Expand Down
3 changes: 2 additions & 1 deletion src/arviz_stats/loo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Pareto-smoothed importance sampling LOO (PSIS-LOO-CV) and K-fold cross-validation functions."""

from arviz_stats.loo.loo import loo
from arviz_stats.loo.loo import loo, loo_i
from arviz_stats.loo.loo_approximate_posterior import loo_approximate_posterior
from arviz_stats.loo.loo_expectations import loo_expectations, loo_metrics
from arviz_stats.loo.loo_pit import loo_pit
Expand All @@ -13,6 +13,7 @@

__all__ = [
"loo",
"loo_i",
"loo_approximate_posterior",
"loo_expectations",
"loo_metrics",
Expand Down
8 changes: 6 additions & 2 deletions src/arviz_stats/loo/helper_loo.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,19 @@ def _get_log_likelihood_i(log_likelihood, i, obs_dims):
obs_dim = obs_dims[0]
if i < 0 or i >= log_likelihood.sizes[obs_dim]:
raise IndexError(f"Index {i} is out of bounds for dimension '{obs_dim}'.")
log_lik_i = log_likelihood.isel({obs_dim: i})
log_lik_i = log_likelihood.isel({obs_dim: slice(i, i + 1)})
else:
stacked_obs_dim = "__obs__"
log_lik_stacked = log_likelihood.stack({stacked_obs_dim: obs_dims})

if i < 0 or i >= log_lik_stacked.sizes[stacked_obs_dim]:
raise IndexError(
f"Index {i} is out of bounds for stacked dimension '{stacked_obs_dim}'."
)
log_lik_i = log_lik_stacked.isel({stacked_obs_dim: i})

log_lik_i = log_lik_stacked.isel({stacked_obs_dim: slice(i, i + 1)})
log_lik_i = log_lik_i.unstack(stacked_obs_dim)

return log_lik_i


Expand Down
153 changes: 153 additions & 0 deletions src/arviz_stats/loo/loo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Pareto-smoothed importance sampling LOO (PSIS-LOO-CV)."""

import numpy as np
from arviz_base import rcParams
from xarray_einstats.stats import logsumexp

from arviz_stats.loo.helper_loo import (
_compute_loo_results,
_get_log_likelihood_i,
_get_r_eff,
_prepare_loo_inputs,
_warn_pareto_k,
)
from arviz_stats.utils import ELPDData


def loo(
Expand Down Expand Up @@ -162,3 +167,151 @@ def loo(
approx_posterior=False,
log_jacobian=log_jacobian,
)


def loo_i(
i,
data,
var_name=None,
reff=None,
log_weights=None,
pareto_k=None,
):
r"""Compute PSIS-LOO-CV for a single observation.

Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
importance sampling leave-one-out cross-validation (PSIS-LOO-CV) for a single observation.
The method is described in [1]_ and [2]_.

Parameters
----------
i : int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use idx?

Copy link
Member Author

@jordandeklerk jordandeklerk Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point. I initially thought that this function would just be for quick testing mainly with custom log-likelihood functions so the actual indexing wouldn't be relevant. But I think maybe we should account for more calculated options here for the user.

The most recent commit allows for several different ways of doing this. Maybe this is overkill though?

Index of the observation for which to compute LOO. Must be between
0 and N-1 where N is the total number of observations.
data : DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood groups.
var_name : str, optional
The name of the variable in log_likelihood groups storing the pointwise log
likelihood data to use for loo computation.
reff : float, optional
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
of actual samples. Computed from trace by default.
log_weights : DataArray, optional
Smoothed log weights for observation i. If not provided, will be computed using PSIS.
Must be provided together with pareto_k or both must be None.
pareto_k : float, optional
Pareto shape value for observation i. If not provided, will be computed using PSIS.
Must be provided together with log_weights or both must be None.

Returns
-------
ELPDData
Object with the following attributes:

- **elpd**: expected log pointwise predictive density for observation i
- **se**: standard error (set to 0.0 as SE is undefined for a single observation)
- **p**: effective number of parameters for observation i
- **n_samples**: number of samples
- **n_data_points**: 1 (single observation)
- **warning**: True if the estimated shape parameter of Pareto distribution is greater
than ``good_k``
- **elpd_i**: :class:`~xarray.DataArray` with single value
- **pareto_k**: :class:`~xarray.DataArray` with single Pareto shape value
- **good_k**: For a sample size S, the threshold is computed as
``min(1 - 1/log10(S), 0.7)``
- **log_weights**: Smoothed log weights for observation i

Notes
-----
This function is useful for testing log-likelihood functions and getting detailed diagnostics
for individual observations. It's particularly helpful when debugging PSIS-LOO-CV computations
for large datasets using :func:`loo_subsample` with the PLPD approximation method, or when
verifying log-likelihood implementations with :func:`loo_moment_match`.

Since this computes PSIS-LOO-CV for a single observation, the standard error is set to 0.0 as
variance cannot be computed from a single value.

Examples
--------
Compute LOO for a single observation:

.. ipython::

In [1]: from arviz_stats import loo_i
...: from arviz_base import load_arviz_data
...: data = load_arviz_data("centered_eight")
...: loo_data_i = loo_i(0, data)
...: loo_data_i

Check the Pareto shape diagnostics for a specific observation:

.. ipython::

In [2]: loo_data_i.pareto_k

See Also
--------
:func:`loo` : Compute LOO for all observations
:func:`compare` : Compare models based on their ELPD.

References
----------

.. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4
arXiv preprint https://arxiv.org/abs/1507.04544.

.. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
arXiv preprint https://arxiv.org/abs/1507.02646
"""
if not isinstance(i, int | np.integer):
raise TypeError(f"i must be an integer, got {type(i)}")

loo_inputs = _prepare_loo_inputs(data, var_name)

i = int(i)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You checked that i is integer

if i < 0 or i >= loo_inputs.n_data_points:
raise ValueError(f"Index i must be between 0 and {loo_inputs.n_data_points - 1}, got {i}")

if reff is None:
reff = _get_r_eff(data, loo_inputs.n_samples)

log_lik_i = _get_log_likelihood_i(loo_inputs.log_likelihood, i, loo_inputs.obs_dims)

if (log_weights is None) != (pareto_k is None):
raise ValueError(
"Both log_weights and pareto_k must be provided together or both must be None. "
"Only one was provided."
)

if log_weights is None and pareto_k is None:
log_weights_i, pareto_k_i = log_lik_i.azstats.psislw(r_eff=reff, dim=loo_inputs.sample_dims)
else:
log_weights_i = log_weights
pareto_k_i = pareto_k

log_weights_sum = log_weights_i + log_lik_i

elpd_i = logsumexp(log_weights_sum, dims=loo_inputs.sample_dims).item()
lppd_i = logsumexp(log_lik_i, b=1 / loo_inputs.n_samples, dims=loo_inputs.sample_dims).item()
p_loo_i = lppd_i - elpd_i
elpd_se = 0.0

warn_mg, good_k = _warn_pareto_k(pareto_k_i, loo_inputs.n_samples)

return ELPDData(
kind="loo",
elpd=elpd_i,
se=elpd_se,
p=p_loo_i,
n_samples=loo_inputs.n_samples,
n_data_points=1,
scale="log",
warning=warn_mg,
good_k=good_k,
elpd_i=elpd_i,
pareto_k=pareto_k_i,
approx_posterior=False,
log_weights=log_weights_i,
)
26 changes: 26 additions & 0 deletions tests/test_loo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
loo,
loo_approximate_posterior,
loo_expectations,
loo_i,
loo_metrics,
loo_moment_match,
loo_pit,
Expand Down Expand Up @@ -830,6 +831,31 @@ def test_log_weights_reuse(centered_eight):
assert hasattr(metrics, "mean")


def test_loo_i(centered_eight):
loo_full = loo(centered_eight, pointwise=True)

result_0 = loo_i(0, centered_eight)
assert isinstance(result_0, ELPDData)
assert result_0.kind == "loo"
assert result_0.n_data_points == 1
assert result_0.n_samples == 2000
assert_almost_equal(result_0.elpd, loo_full.elpd_i[0].item(), decimal=10)
assert_almost_equal(result_0.pareto_k.item(), loo_full.pareto_k[0].item(), decimal=10)

result_7 = loo_i(7, centered_eight)
assert_almost_equal(result_7.elpd, loo_full.elpd_i[7].item(), decimal=10)
assert_almost_equal(result_7.pareto_k.item(), loo_full.pareto_k[7].item(), decimal=10)

with pytest.raises(ValueError, match="Index i must be between"):
loo_i(-1, centered_eight)

with pytest.raises(ValueError, match="Index i must be between"):
loo_i(8, centered_eight)

with pytest.raises(TypeError, match="i must be an integer"):
loo_i(3.5, centered_eight)


def test_loo_jacobian(centered_eight):
loo_no_jacobian = loo(centered_eight, pointwise=True)

Expand Down