@jordandeklerk jordandeklerk commented Sep 4, 2025

This implements new functions that support the PLPD approximation method alongside the existing LPD method for PSIS-LOO-CV sub-sampling. It also moves some of the helper functions in helper_loo.py so that similar functions are grouped together.

The _compute_loo_approximation() function is the biggest change here, handling both the PLPD and LPD approximations with user-provided log-likelihood functions.

Major Changes

New _compute_loo_approximation() function:

  • Streamlined interface: The function handles both lpd and plpd methods through a single, clean interface and also accepts a custom log-likelihood function for the LPD approximation
  • Better log-likelihood API: A more natural interface for user-provided log-likelihood functions
  • Improved validation: Added more parameter validation and error handling
  • Better data alignment: Fixed auxiliary-data alignment issues when subsampling observations, so users can pass extra data needed for log-likelihood computations

When using custom log-likelihood functions with loo_subsample(), the function must take the observed data and the DataTree object as arguments. For both approximations, I think this is all the user needs to construct a log-likelihood function.

The DataTree should also contain a constant_data group with any extra data, such as covariates, needed to construct the log-likelihood function. The internal helper looks for auxiliary data in this group, and only this group, when building the log-likelihood function (as in the Wells example below):

def log_lik_fn(observations, data):
    """Custom log-likelihood function for LOO approximations.

    Parameters
    ----------
    observations : DataArray
        The observed data (from data.observed_data[var_name]) to compute
        the log-likelihood for.
    data : DataTree
        Contains the posterior samples and any auxiliary data.
        - For LPD: data.posterior contains full MCMC samples (chains × draws × parameters)
        - For PLPD: data.posterior contains posterior means (parameters only)
        - data.constant_data contains any auxiliary data (covariates, etc.)

    Returns
    -------
    array-like
        Log-likelihood values with shape:
        - For LPD: (n_chains, n_draws, *observation_shape)
        - For PLPD: observation_shape
    """

I validated this rework against the loo package using the Wells dataset example from https://mc-stan.org/loo/articles/loo2-large-data.html. You can download the Wells data from the loo package:

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from arviz_base import convert_to_datatree
from scipy import stats
from scipy.special import expit
from arviz_stats import loo_subsample, update_subsample


def log_lik_fun(obs, data):
    beta = data.posterior["beta"].values
    X = data.constant_data["X"].values
    
    logit_pred = X @ beta
    prob = expit(logit_pred)
    return stats.bernoulli.logpmf(obs, prob)
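Note that `X @ beta` only broadcasts when `beta` is a 1-D vector of posterior means, i.e. the PLPD case used below. For the LPD approximation, where `data.posterior["beta"]` carries full `(chain, draw, coef)` samples, a variant using `einsum` would be needed; a sketch under the same assumptions (same hypothetical variable names as above):

```python
import numpy as np
from scipy import stats
from scipy.special import expit


def log_lik_fun_lpd(obs, data):
    # beta: (chain, draw, coef) full posterior samples
    beta = data.posterior["beta"].values
    # X: (obs, coef) covariate matrix from constant_data
    X = data.constant_data["X"].values

    # contract the coefficient axis, keeping (chain, draw, obs)
    logit_pred = np.einsum("oc,kdc->kdo", X, beta)
    return stats.bernoulli.logpmf(obs, expit(logit_pred))
```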

wells = pd.read_csv("wells.csv")
wells["dist100"] = wells["dist"] / 100

X = np.column_stack([
    np.ones(len(wells)),        
    wells["dist100"].values,     
    wells["arsenic"].values 
])
y = wells["switch"].values

with pm.Model():
    beta = pm.Normal("beta", mu=0, sigma=1, shape=3)
    logit_p = pm.math.dot(X, beta)
    pm.Bernoulli("y", logit_p=logit_p, observed=y)
    
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        random_seed=4711,
        progressbar=True,
    )

data = convert_to_datatree(idata)
data["constant_data"] = xr.Dataset({"X": (["obs_id", "coef"], X)}).assign_coords({
    "obs_id": np.arange(len(X)),
    "coef": ["intercept", "dist100", "arsenic"]
})

# PLPD approximation
loo_plpd = loo_subsample(
    data=data,
    observations=100,
    var_name="y",
    method="plpd",
    log_lik_fn=log_lik_fun,
    param_names=["beta"],
    pointwise=True,
    seed=4711
)

We get output that is nearly identical to the R output for the PLPD approximation (results for LPD and update_subsample() are similarly close):

arviz-stats:

Computed from 4000 by 100 subsampled log-likelihood 
values from 3020 total observations.

         Estimate   SE  subsampling SE
elpd_loo   -1968.4 15.6            0.3
p_loo          3.1
------
Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)      100  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%

loo package:

Computed from 4000 by 100 subsampled log-likelihood
values from 3020 total observations.

         Estimate   SE subsampling SE
elpd_loo  -1968.5 15.6            0.3
p_loo         3.1  0.1            0.4
looic      3936.9 31.2            0.6
------
Monte Carlo SE of elpd_loo is 0.0.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.9, 1.0]).

All Pareto k estimates are good (k < 0.7).
See help('pareto-k-diagnostic') for details.

Resolves #178


📚 Documentation preview 📚: https://arviz-stats--192.org.readthedocs.build/en/192/

@codecov-commenter commented Sep 4, 2025

Codecov Report

❌ Patch coverage is 69.34866% with 80 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.29%. Comparing base (a40917b) to head (a37cbb7).
⚠️ Report is 118 commits behind head on main.

Files with missing lines           Patch %   Lines
src/arviz_stats/loo/helper_loo.py  69.11%    80 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #192      +/-   ##
==========================================
+ Coverage   62.64%   70.29%   +7.64%     
==========================================
  Files          14       39      +25     
  Lines        1925     4443    +2518     
==========================================
+ Hits         1206     3123    +1917     
- Misses        719     1320     +601     

☔ View full report in Codecov by Sentry.
@jordandeklerk jordandeklerk marked this pull request as ready for review September 5, 2025 01:40