|
7 | 7 |
|
8 | 8 | from fastprogress.fastprogress import progress_bar
|
9 | 9 | from functools import partial
|
10 |
| -from jax import jit |
| 10 | +from jax import jit, tree, vmap |
11 | 11 | from jax.tree_util import tree_map
|
12 | 12 | from jaxtyping import Array, Float
|
13 | 13 | from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
|
14 | 14 | from typing import Any, Optional, Tuple, Union, runtime_checkable
|
15 |
| -from typing_extensions import Protocol |
| 15 | +from typing_extensions import Protocol |
16 | 16 |
|
17 | 17 | from dynamax.ssm import SSM
|
18 | 18 | from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
|
|
24 | 24 | from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
|
25 | 25 | from dynamax.utils.distributions import NormalInverseWishart as NIW
|
26 | 26 | from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
|
27 |
| -from dynamax.utils.utils import pytree_stack, psd_solve |
| 27 | +from dynamax.utils.utils import ensure_array_has_batch_dim, pytree_stack, psd_solve |
28 | 28 |
|
29 | 29 | @runtime_checkable
|
30 | 30 | class SuffStatsLGSSM(Protocol):
|
@@ -206,7 +206,7 @@ def sample(self,
|
206 | 206 | key: PRNGKeyT,
|
207 | 207 | num_timesteps: int,
|
208 | 208 | inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \
|
209 |
| - -> Tuple[Float[Array, "num_timesteps state_dim"], |
| 209 | + -> Tuple[Float[Array, "num_timesteps state_dim"], |
210 | 210 | Float[Array, "num_timesteps emission_dim"]]:
|
211 | 211 | """Sample from the model.
|
212 | 212 |
|
@@ -607,18 +607,20 @@ def fit_blocked_gibbs(self,
|
607 | 607 | Returns:
|
608 | 608 | parameter object, where each field has `sample_size` copies as leading batch dimension.
|
609 | 609 | """
|
610 |
| - num_timesteps = len(emissions) |
| 610 | + batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape) |
| 611 | + batch_inputs = ensure_array_has_batch_dim(inputs, self.inputs_shape) |
611 | 612 |
|
612 |
| - if inputs is None: |
613 |
| - inputs = jnp.zeros((num_timesteps, 0)) |
| 613 | + num_batches, num_timesteps = batch_emissions.shape[:2] |
| 614 | + |
| 615 | + if batch_inputs is None: |
| 616 | + batch_inputs = jnp.zeros((num_batches, num_timesteps, 0)) |
614 | 617 |
|
615 |
| - def sufficient_stats_from_sample(states): |
| 618 | + def sufficient_stats_from_sample(y, inputs, states): |
616 | 619 | """Convert samples of states to sufficient statistics."""
|
617 | 620 | inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
|
618 | 621 | # Let xn[t] = x[t+1] for t = 0...T-2
|
619 | 622 | x, xp, xn = states, states[:-1], states[1:]
|
620 | 623 | u, up = inputs_joint, inputs_joint[:-1]
|
621 |
| - y = emissions |
622 | 624 |
|
623 | 625 | init_stats = (x[0], jnp.outer(x[0], x[0]), 1)
|
624 | 626 |
|
@@ -678,9 +680,13 @@ def one_sample(_params, rng):
|
678 | 680 | """Sample a single set of states and compute their sufficient stats."""
|
679 | 681 | rngs = jr.split(rng, 2)
|
680 | 682 | # Sample latent states
|
681 |
| - states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs) |
| 683 | + batch_keys = jr.split(rngs[0], num=num_batches) |
| 684 | + forward_backward_batched = vmap(partial(lgssm_posterior_sample, params=_params)) |
| 685 | + batch_states = forward_backward_batched(batch_keys, emissions=batch_emissions, inputs=batch_inputs) |
| 686 | + _batch_stats = vmap(sufficient_stats_from_sample)(batch_emissions, batch_inputs, batch_states) |
| 687 | + # Aggregate statistics from all observations. |
| 688 | + _stats = tree.map(lambda x: jnp.sum(x, axis=0), _batch_stats) |
682 | 689 | # Sample parameters
|
683 |
| - _stats = sufficient_stats_from_sample(states) |
684 | 690 | return lgssm_params_sample(rngs[1], _stats)
|
685 | 691 |
|
686 | 692 |
|
|
0 commit comments