diff --git a/tests/smc/test_pretuning.py b/tests/smc/test_pretuning.py index a677c99ae..d24996eaf 100644 --- a/tests/smc/test_pretuning.py +++ b/tests/smc/test_pretuning.py @@ -145,13 +145,85 @@ def test_update_multi_sigmas(self): ) +def tuned_adaptive_tempered_inference_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state, *_ = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state, curr_loglikelihood = carry + subkey = jax.random.fold_in(rng_key, i) + state, info = kernel(subkey, state) + return i + 1, state, curr_loglikelihood + info.log_likelihood_increment + + total_iter, final_state, log_likelihood = jax.lax.while_loop( + cond, body, (0, initial_state, 0.0) + ) + return final_state + + class PretuningSMCTest(SMCLinearRegressionTestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) @chex.variants(with_jit=True) - def test_linear_regression(self): + def test_tempered(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel( + blackjax.tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + ) + + def loop(smc_kernel, init_particles, initial_parameters): + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (i + 1, new_state), (new_state, info) + + num_tempering_steps = 10 + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + return result + + self.linear_regression_test_case(step_provider, loop) + + @chex.variants(with_jit=True) + def test_adaptive_tempered(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + target_ess=0.5, + ) + + def loop(smc_kernel, init_particles, initial_parameters): + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + return tuned_adaptive_tempered_inference_loop( + smc_kernel, self.key, initial_state + ) + + self.linear_regression_test_case(step_provider, loop) + + def linear_regression_test_case(self, step_provider, loop): ( init_particles, logprior_fn, @@ -191,32 +263,11 @@ def test_linear_regression(self): positive_parameters=["step_size"], ) - step = blackjax.smc.pretuning.build_kernel( - blackjax.tempered_smc, - logprior_fn, - loglikelihood_fn, - blackjax.hmc.build_kernel(), - blackjax.hmc.init, - resampling.systematic, - num_mcmc_steps=10, - pretune_fn=pretune, - ) + step = step_provider(logprior_fn, loglikelihood_fn, pretune) - initial_state = init( - blackjax.tempered_smc.init, init_particles, initial_parameters - ) smc_kernel = self.variant(step) - def body_fn(carry, lmbda): - i, state = carry - subkey = jax.random.fold_in(self.key, i) - new_state, info = smc_kernel(subkey, state, lmbda=lmbda) - return (i + 1, new_state), (new_state, info) - - num_tempering_steps = 10 - lambda_schedule = np.logspace(-5, 0, num_tempering_steps) - - (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + result = loop(smc_kernel, init_particles, initial_parameters) self.assert_linear_regression_test_case(result.sampler_state) assert set(result.parameter_override.keys()) == { "step_size",