diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index e9496621..4e9ca04e 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -7,12 +7,12 @@ from fastprogress.fastprogress import progress_bar from functools import partial -from jax import jit, vmap +from jax import jit, tree, vmap from jax.tree_util import tree_map from jaxtyping import Array, Float from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN from typing import Any, Optional, Tuple, Union, runtime_checkable -from typing_extensions import Protocol +from typing_extensions import Protocol from dynamax.ssm import SSM from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample @@ -24,7 +24,7 @@ from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW from dynamax.utils.distributions import NormalInverseWishart as NIW from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update -from dynamax.utils.utils import pytree_stack, psd_solve +from dynamax.utils.utils import ensure_array_has_batch_dim, pytree_stack, psd_solve @runtime_checkable class SuffStatsLGSSM(Protocol): @@ -206,7 +206,7 @@ def sample(self, key: PRNGKeyT, num_timesteps: int, inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \ - -> Tuple[Float[Array, "num_timesteps state_dim"], + -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: """Sample from the model. @@ -357,7 +357,7 @@ def forecast(self, input_weights=params.emissions.input_weights, cov=1e8 * jnp.ones(self.emission_dim)) # ignore dummy observatiosn ) - + dummy_emissions = jnp.zeros((num_forecast_timesteps, self.emission_dim)) forecast_inputs = forecast_inputs if forecast_inputs is not None else \ jnp.zeros((num_forecast_timesteps, 0)) @@ -367,7 +367,7 @@ def forecast(self, H = params.emissions.weights b = params.emissions.bias R = params.emissions.cov if params.emissions.cov.ndim == 2 else jnp.diag(params.emissions.cov) - + forecast_emissions = forecast_states.filtered_means @ H.T + b forecast_emissions_cov = H @ forecast_states.filtered_covariances @ H.T + R return forecast_states.filtered_means, \ @@ -662,18 +662,20 @@ def fit_blocked_gibbs(self, Returns: parameter object, where each field has `sample_size` copies as leading batch dimension. """ - num_timesteps = len(emissions) + batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape) + batch_inputs = ensure_array_has_batch_dim(inputs, self.inputs_shape) - if inputs is None: - inputs = jnp.zeros((num_timesteps, 0)) + num_batches, num_timesteps = batch_emissions.shape[:2] + + if batch_inputs is None: + batch_inputs = jnp.zeros((num_batches, num_timesteps, 0)) - def sufficient_stats_from_sample(states): + def sufficient_stats_from_sample(y, inputs, states): """Convert samples of states to sufficient statistics.""" inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1) # Let xn[t] = x[t+1] for t = 0...T-2 x, xp, xn = states, states[:-1], states[1:] u, up = inputs_joint, inputs_joint[:-1] - y = emissions init_stats = (x[0], jnp.outer(x[0], x[0]), 1) @@ -733,9 +735,13 @@ def one_sample(_params, rng): """Sample a single set of states and compute their sufficient stats.""" rngs = jr.split(rng, 2) # Sample latent states - states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs) + batch_keys = jr.split(rngs[0], num=num_batches) + forward_backward_batched = vmap(partial(lgssm_posterior_sample, params=_params)) + batch_states = forward_backward_batched(batch_keys, emissions=batch_emissions, inputs=batch_inputs) + _batch_stats = vmap(sufficient_stats_from_sample)(batch_emissions, batch_inputs, batch_states) + # Aggregate statistics from all observations. + _stats = tree.map(lambda x: jnp.sum(x, axis=0), _batch_stats) # Sample parameters - _stats = sufficient_stats_from_sample(states) return lgssm_params_sample(rngs[1], _stats) diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index 50b5aff8..bd156742 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -1,8 +1,12 @@ """ Tests for the linear Gaussian SSM models. """ +from functools import partial +from itertools import count import pytest +from jax import vmap +import jax.numpy as jnp import jax.random as jr from dynamax.linear_gaussian_ssm import LinearGaussianSSM @@ -29,3 +33,20 @@ def test_sample_and_fit(cls, kwargs, inputs): fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3) assert monotonically_increasing(lps) fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3) + +def test_fit_blocked_gibbs_batched(): + """ + Test that the blocked Gibbs sampler works for multiple observations. + """ + state_dim = 2 + emission_dim = 3 + num_timesteps = 4 + m_samples = 5 + keys = map(jr.PRNGKey, count()) + m_keys = jr.split(next(keys), num=m_samples) + + model = LinearGaussianConjugateSSM(state_dim, emission_dim) + params, _ = model.initialize(next(keys)) + _, y_obs = vmap(partial(model.sample, params, num_timesteps=num_timesteps))(m_keys) + + model.fit_blocked_gibbs(next(keys), params, sample_size=6, emissions=y_obs) \ No newline at end of file