Skip to content

Commit

Permalink
test CI: old tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jan 21, 2025
1 parent 7974eea commit fccb388
Showing 1 changed file with 7 additions and 106 deletions.
113 changes: 7 additions & 106 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def run_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand All @@ -148,7 +147,7 @@ def run_mclmc(

return samples

def run_adjusted_mclmc_dynamic(
def run_adjusted_mclmc(
self,
logdensity_fn,
num_steps,
Expand Down Expand Up @@ -179,12 +178,11 @@ def run_adjusted_mclmc_dynamic(
logdensity_fn=logdensity_fn,
)

target_acc_rate = 0.9
target_acc_rate = 0.65

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -221,74 +219,6 @@ def run_adjusted_mclmc_dynamic(

return out

def run_adjusted_mclmc(
self,
logdensity_fn,
num_steps,
initial_position,
key,
diagonal_preconditioning=False,
):
integrator = isokinetic_mclachlan

init_key, tune_key, run_key = jax.random.split(key, 3)

initial_state = blackjax.mcmc.adjusted_mclmc.init(
position=initial_position,
logdensity_fn=logdensity_fn,
)

kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel(
integrator=integrator,
inverse_mass_matrix=inverse_mass_matrix,
logdensity_fn=logdensity_fn,
)(
rng_key=rng_key,
state=state,
step_size=step_size,
num_integration_steps=avg_num_integration_steps,
)

target_acc_rate = 0.9

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
target=target_acc_rate,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
diagonal_preconditioning=diagonal_preconditioning,
)

step_size = blackjax_mclmc_sampler_params.step_size
L = blackjax_mclmc_sampler_params.L

alg = blackjax.adjusted_mclmc(
logdensity_fn=logdensity_fn,
step_size=step_size,
num_integration_steps=L / step_size,
integrator=integrator,
inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix,
)

_, out = run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=alg,
num_steps=num_steps,
transform=lambda state, _: state.position,
progress_bar=False,
)

return out

@parameterized.parameters(
itertools.product(
regression_test_cases, [True, False], window_adaptation_filters
Expand Down Expand Up @@ -405,38 +335,7 @@ def test_mclmc(self):
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)

@parameterized.parameters([True, False])
def test_adjusted_mclmc_dynamic(
self,
diagonal_preconditioning,
):
"""Test the MCLMC kernel."""

init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
x_data = jax.random.normal(init_key0, shape=(1000, 1))
y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)

logposterior_fn_ = functools.partial(
self.regression_logprob, x=x_data, preds=y_data
)
logdensity_fn = lambda x: logposterior_fn_(**x)

states = self.run_adjusted_mclmc_dynamic(
initial_position={"coefs": 1.0, "log_scale": 1.0},
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
diagonal_preconditioning=diagonal_preconditioning,
)

coefs_samples = states["coefs"][3000:]
scale_samples = np.exp(states["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)

@parameterized.parameters([True, False])
def test_adjusted_mclmc(self, diagonal_preconditioning):
def test_adjusted_mclmc(self):
"""Test the MCLMC kernel."""

init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
Expand All @@ -453,7 +352,6 @@ def test_adjusted_mclmc(self, diagonal_preconditioning):
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
diagonal_preconditioning=diagonal_preconditioning,
)

coefs_samples = states["coefs"][3000:]
Expand Down Expand Up @@ -519,7 +417,10 @@ def get_inverse_mass_matrix():
inverse_mass_matrix=inverse_mass_matrix,
)

(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
(
_,
blackjax_mclmc_sampler_params,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
Expand Down

0 comments on commit fccb388

Please sign in to comment.