Skip to content

Commit

Permalink
test in place (#772)
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran authored Jan 22, 2025
1 parent 4d4eae0 commit a053bed
Showing 1 changed file with 75 additions and 24 deletions.
99 changes: 75 additions & 24 deletions tests/smc/test_pretuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,85 @@ def test_update_multi_sigmas(self):
)


def tuned_adaptive_tempered_inference_loop(kernel, rng_key, initial_state):
def cond(carry):
_, state, *_ = carry
return state.sampler_state.lmbda < 1

def body(carry):
i, state, curr_loglikelihood = carry
subkey = jax.random.fold_in(rng_key, i)
state, info = kernel(subkey, state)
return i + 1, state, curr_loglikelihood + info.log_likelihood_increment

total_iter, final_state, log_likelihood = jax.lax.while_loop(
cond, body, (0, initial_state, 0.0)
)
return final_state


class PretuningSMCTest(SMCLinearRegressionTestCase):
def setUp(self):
super().setUp()
self.key = jax.random.key(42)

@chex.variants(with_jit=True)
def test_linear_regression(self):
def test_tempered(self):
step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel(
blackjax.tempered_smc,
logprior_fn,
loglikelihood_fn,
blackjax.hmc.build_kernel(),
blackjax.hmc.init,
resampling.systematic,
num_mcmc_steps=10,
pretune_fn=pretune,
)

def loop(smc_kernel, init_particles, initial_parameters):
initial_state = init(
blackjax.tempered_smc.init, init_particles, initial_parameters
)

def body_fn(carry, lmbda):
i, state = carry
subkey = jax.random.fold_in(self.key, i)
new_state, info = smc_kernel(subkey, state, lmbda=lmbda)
return (i + 1, new_state), (new_state, info)

num_tempering_steps = 10
lambda_schedule = np.logspace(-5, 0, num_tempering_steps)

(_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule)
return result

self.linear_regression_test_case(step_provider, loop)

@chex.variants(with_jit=True)
def test_adaptive_tempered(self):
step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel(
blackjax.adaptive_tempered_smc,
logprior_fn,
loglikelihood_fn,
blackjax.hmc.build_kernel(),
blackjax.hmc.init,
resampling.systematic,
num_mcmc_steps=10,
pretune_fn=pretune,
target_ess=0.5,
)

def loop(smc_kernel, init_particles, initial_parameters):
initial_state = init(
blackjax.tempered_smc.init, init_particles, initial_parameters
)
return tuned_adaptive_tempered_inference_loop(
smc_kernel, self.key, initial_state
)

self.linear_regression_test_case(step_provider, loop)

def linear_regression_test_case(self, step_provider, loop):
(
init_particles,
logprior_fn,
Expand Down Expand Up @@ -191,32 +263,11 @@ def test_linear_regression(self):
positive_parameters=["step_size"],
)

step = blackjax.smc.pretuning.build_kernel(
blackjax.tempered_smc,
logprior_fn,
loglikelihood_fn,
blackjax.hmc.build_kernel(),
blackjax.hmc.init,
resampling.systematic,
num_mcmc_steps=10,
pretune_fn=pretune,
)
step = step_provider(logprior_fn, loglikelihood_fn, pretune)

initial_state = init(
blackjax.tempered_smc.init, init_particles, initial_parameters
)
smc_kernel = self.variant(step)

def body_fn(carry, lmbda):
i, state = carry
subkey = jax.random.fold_in(self.key, i)
new_state, info = smc_kernel(subkey, state, lmbda=lmbda)
return (i + 1, new_state), (new_state, info)

num_tempering_steps = 10
lambda_schedule = np.logspace(-5, 0, num_tempering_steps)

(_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule)
result = loop(smc_kernel, init_particles, initial_parameters)
self.assert_linear_regression_test_case(result.sampler_state)
assert set(result.parameter_override.keys()) == {
"step_size",
Expand Down

0 comments on commit a053bed

Please sign in to comment.