diff --git a/Project.toml b/Project.toml index edb687fb..abd6643a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.4" +version = "0.7.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/adaptation/stepsize.jl b/src/adaptation/stepsize.jl index 9e45f33a..ebfdb363 100644 --- a/src/adaptation/stepsize.jl +++ b/src/adaptation/stepsize.jl @@ -71,12 +71,12 @@ References Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: adaptively setting path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593-1623. Nesterov, Y. (2009). Primal-dual subgradient methods for convex problems. Mathematical programming, 120(1), 221-259. """ -struct NesterovDualAveraging{T<:AbstractFloat} <: StepSizeAdaptor +struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: StepSizeAdaptor γ::T t_0::T κ::T δ::T - state::DAState{<:AbstractScalarOrVec{T}} + state::DAState{S} end Base.show(io::IO, a::NesterovDualAveraging) = print( io, diff --git a/src/constructors.jl b/src/constructors.jl index 297e9bdd..9b3d5ea5 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -144,15 +144,16 @@ For more information, please view the following paper ([arXiv link](https://arxi setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning Research 15, no. 1 (2014): 1593-1623. """ -struct HMCDA{T<:Real} <: AbstractHMCSampler +struct HMCDA{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <: + AbstractHMCSampler "Target acceptance rate for dual averaging." δ::T "Target leapfrog length." λ::T "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)" - integrator::Union{Symbol,AbstractIntegrator} + integrator::I "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption." - metric::Union{Symbol,AbstractMetric} + metric::M end function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal) diff --git a/src/trajectory.jl b/src/trajectory.jl index e8ab1194..23fb3546 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -99,9 +99,9 @@ It contains the slice variable and the number of acceptable condidates in the tr $(TYPEDFIELDS) """ -struct SliceTS{F<:AbstractFloat} <: AbstractTrajectorySampler +struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler "Sampled candidate `PhasePoint`." - zcand::PhasePoint + zcand::P "Slice variable in log-space." ℓu::F "Number of acceptable candidates, i.e. those with probability larger than slice variable `u`." @@ -120,9 +120,9 @@ It contains the weight of the tree, defined as the total probabilities of the le $(TYPEDFIELDS) """ -struct MultinomialTS{F<:AbstractFloat} <: AbstractTrajectorySampler +struct MultinomialTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler "Sampled candidate `PhasePoint`." - zcand::PhasePoint + zcand::P "Total energy for the given tree, i.e. the sum of energies of all leaves." ℓw::F end @@ -499,13 +499,13 @@ end """ A full binary tree trajectory with only necessary leaves and information stored. """ -struct BinaryTree - zleft::Any # left most leaf node - zright::Any # right most leaf node - ts::Any # turn statistics - sum_α::Any # MH stats, i.e. sum of MH accept prob for all leapfrog steps - nα::Any # total # of leap frog steps, i.e. phase points in a trajectory - ΔH_max::Any # energy in tree with largest absolute different from initial energy +struct BinaryTree{T<:Real,P<:PhasePoint,TS<:TurnStatistic} + zleft::P # left most leaf node + zright::P # right most leaf node + ts::TS # turn statistics + sum_α::T # MH stats, i.e. sum of MH accept prob for all leapfrog steps + nα::Int # total # of leap frog steps, i.e. phase points in a trajectory + ΔH_max::T # energy in tree with largest absolute different from initial energy end """ diff --git a/test/trajectory.jl b/test/trajectory.jl index 2a7cb310..da6da61e 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -72,7 +72,7 @@ ahmc_isturn(h, z0, z1, rho, v = 1) = AdvancedHMC.isterminated( ClassicNoUTurn(), h, - AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0.0, 0, 0.0), ).dynamic function hand_isturn_generalised(z0, z1, rho, v = 1) @@ -84,16 +84,16 @@ ahmc_isturn_generalised(h, z0, z1, rho, v = 1) = AdvancedHMC.isterminated( GeneralisedNoUTurn(), h, - AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0), ).dynamic function ahmc_isturn_strictgeneralised(h, z0, z1, rho, v = 1) t = AdvancedHMC.isterminated( StrictGeneralisedNoUTurn(), h, - AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0), - AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0), - AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0), + AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0), ) return t.dynamic end @@ -102,13 +102,14 @@ end Check whether the subtree checks adequately detect U-turns. """ function check_subtree_u_turns(h, z0, z1, rho) - t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0) + t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0) # The left and right subtree are created in such a way that the # check_left_subtree and check_right_subtree checks should be equivalent # to the general no U-turn check. - tleft = AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0) + tleft = + AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0) tright = - AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0) + AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0) s1 = AdvancedHMC.isterminated(GeneralisedNoUTurn(), h, t) s2 = AdvancedHMC.check_left_subtree(h, t, tleft, tright)