Skip to content

Commit 0a6c0fc

Browse files
author
Hylke Donker
committed
Fixes #400
1 parent 35e1217 commit 0a6c0fc

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

dynamax/linear_gaussian_ssm/models.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
from fastprogress.fastprogress import progress_bar
99
from functools import partial
10-
from jax import jit
10+
from jax import jit, tree, vmap
1111
from jax.tree_util import tree_map
1212
from jaxtyping import Array, Float
1313
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
1414
from typing import Any, Optional, Tuple, Union, runtime_checkable
15-
from typing_extensions import Protocol
15+
from typing_extensions import Protocol
1616

1717
from dynamax.ssm import SSM
1818
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
@@ -24,7 +24,7 @@
2424
from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
2525
from dynamax.utils.distributions import NormalInverseWishart as NIW
2626
from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
27-
from dynamax.utils.utils import pytree_stack, psd_solve
27+
from dynamax.utils.utils import ensure_array_has_batch_dim, pytree_stack, psd_solve
2828

2929
@runtime_checkable
3030
class SuffStatsLGSSM(Protocol):
@@ -206,7 +206,7 @@ def sample(self,
206206
key: PRNGKeyT,
207207
num_timesteps: int,
208208
inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \
209-
-> Tuple[Float[Array, "num_timesteps state_dim"],
209+
-> Tuple[Float[Array, "num_timesteps state_dim"],
210210
Float[Array, "num_timesteps emission_dim"]]:
211211
"""Sample from the model.
212212
@@ -607,18 +607,20 @@ def fit_blocked_gibbs(self,
607607
Returns:
608608
parameter object, where each field has `sample_size` copies as leading batch dimension.
609609
"""
610-
num_timesteps = len(emissions)
610+
batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape)
611+
batch_inputs = ensure_array_has_batch_dim(inputs, self.inputs_shape)
611612

612-
if inputs is None:
613-
inputs = jnp.zeros((num_timesteps, 0))
613+
num_batches, num_timesteps = batch_emissions.shape[:2]
614+
615+
if batch_inputs is None:
616+
batch_inputs = jnp.zeros((num_batches, num_timesteps, 0))
614617

615-
def sufficient_stats_from_sample(states):
618+
def sufficient_stats_from_sample(y, inputs, states):
616619
"""Convert samples of states to sufficient statistics."""
617620
inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
618621
# Let xn[t] = x[t+1] for t = 0...T-2
619622
x, xp, xn = states, states[:-1], states[1:]
620623
u, up = inputs_joint, inputs_joint[:-1]
621-
y = emissions
622624

623625
init_stats = (x[0], jnp.outer(x[0], x[0]), 1)
624626

@@ -678,9 +680,13 @@ def one_sample(_params, rng):
678680
"""Sample a single set of states and compute their sufficient stats."""
679681
rngs = jr.split(rng, 2)
680682
# Sample latent states
681-
states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs)
683+
batch_keys = jr.split(rngs[0], num=num_batches)
684+
forward_backward_batched = vmap(partial(lgssm_posterior_sample, params=_params))
685+
batch_states = forward_backward_batched(batch_keys, emissions=batch_emissions, inputs=batch_inputs)
686+
_batch_stats = vmap(sufficient_stats_from_sample)(batch_emissions, batch_inputs, batch_states)
687+
# Aggregate statistics from all observations.
688+
_stats = tree.map(lambda x: jnp.sum(x, axis=0), _batch_stats)
682689
# Sample parameters
683-
_stats = sufficient_stats_from_sample(states)
684690
return lgssm_params_sample(rngs[1], _stats)
685691

686692

dynamax/linear_gaussian_ssm/models_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
"""
22
Tests for the linear Gaussian SSM models.
33
"""
4+
from functools import partial
5+
from itertools import count
46

57
import pytest
8+
from jax import vmap
9+
import jax.numpy as jnp
610
import jax.random as jr
711

812
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
@@ -29,3 +33,20 @@ def test_sample_and_fit(cls, kwargs, inputs):
2933
fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3)
3034
assert monotonically_increasing(lps)
3135
fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3)
36+
37+
def test_fit_blocked_gibbs_batched():
38+
"""
39+
Test that the blocked Gibbs sampler works for multiple observations.
40+
"""
41+
state_dim = 2
42+
emission_dim = 3
43+
num_timesteps = 4
44+
m_samples = 5
45+
keys = map(jr.PRNGKey, count())
46+
m_keys = jr.split(next(keys), num=m_samples)
47+
48+
model = LinearGaussianConjugateSSM(state_dim, emission_dim)
49+
params, _ = model.initialize(next(keys))
50+
_, y_obs = vmap(partial(model.sample, params, num_timesteps=num_timesteps))(m_keys)
51+
52+
model.fit_blocked_gibbs(next(keys), params, sample_size=6, emissions=y_obs)

0 commit comments

Comments
 (0)