diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 0e2cdb63..9403bffe 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -165,8 +165,8 @@ function AbstractMCMC.step( # Compute next transition and state. state = HMCState(0, t, metric, κ, adaptor) - # Return the initial transition and state. - return Transition(t.z, merge(stat(t), (is_adapt = false,))), state + # Take actual first step. + return AbstractMCMC.step(rng, model, spl, state; kwargs...) end function AbstractMCMC.step( @@ -260,13 +260,10 @@ function (cb::HMCProgressCallback)( κ = state.κ tstat = t.stat isadapted = tstat.is_adapt - # The initial transition will not much information beyond the `is_adapt` field. - if haskey(tstat, :numerical_error) - if isadapted - cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error - else - cb.num_divergent_transitions[] += tstat.numerical_error - end + if isadapted + cb.num_divergent_transitions_during_adaption[] += tstat.numerical_error + else + cb.num_divergent_transitions[] += tstat.numerical_error end # Update progress meter