From ae3996e15c42aad711af77c039cdea1f5a92d55b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 23 Feb 2025 00:19:59 +0100 Subject: [PATCH 1/5] Fix fields with abstract types --- src/adaptation/stepsize.jl | 4 ++-- src/constructors.jl | 6 +++--- src/trajectory.jl | 20 ++++++++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/adaptation/stepsize.jl b/src/adaptation/stepsize.jl index 9e45f33a..ebfdb363 100644 --- a/src/adaptation/stepsize.jl +++ b/src/adaptation/stepsize.jl @@ -71,12 +71,12 @@ References Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: adaptively setting path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593-1623. Nesterov, Y. (2009). Primal-dual subgradient methods for convex problems. Mathematical programming, 120(1), 221-259. """ -struct NesterovDualAveraging{T<:AbstractFloat} <: StepSizeAdaptor +struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: StepSizeAdaptor γ::T t_0::T κ::T δ::T - state::DAState{<:AbstractScalarOrVec{T}} + state::DAState{S} end Base.show(io::IO, a::NesterovDualAveraging) = print( io, diff --git a/src/constructors.jl b/src/constructors.jl index 297e9bdd..dbfa1f66 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -144,15 +144,15 @@ For more information, please view the following paper ([arXiv link](https://arxi setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning Research 15, no. 1 (2014): 1593-1623. """ -struct HMCDA{T<:Real} <: AbstractHMCSampler +struct HMCDA{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <: AbstractHMCSampler "Target acceptance rate for dual averaging." δ::T "Target leapfrog length." λ::T "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)" - integrator::Union{Symbol,AbstractIntegrator} + integrator::I "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption." - metric::Union{Symbol,AbstractMetric} + metric::M end function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal) diff --git a/src/trajectory.jl b/src/trajectory.jl index c924b8cc..2eefbcc7 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -99,7 +99,7 @@ It contains the slice variable and the number of acceptable condidates in the tr $(TYPEDFIELDS) """ -struct SliceTS{F<:AbstractFloat} <: AbstractTrajectorySampler +struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler "Sampled candidate `PhasePoint`." zcand::PhasePoint "Slice variable in log-space." @@ -120,9 +120,9 @@ It contains the weight of the tree, defined as the total probabilities of the le $(TYPEDFIELDS) """ -struct MultinomialTS{F<:AbstractFloat} <: AbstractTrajectorySampler +struct MultinomialTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler "Sampled candidate `PhasePoint`." - zcand::PhasePoint + zcand::P "Total energy for the given tree, i.e. the sum of energies of all leaves." ℓw::F end @@ -499,13 +499,13 @@ end """ A full binary tree trajectory with only necessary leaves and information stored. """ -struct BinaryTree - zleft::Any # left most leaf node - zright::Any # right most leaf node - ts::Any # turn statistics - sum_α::Any # MH stats, i.e. sum of MH accept prob for all leapfrog steps - nα::Any # total # of leap frog steps, i.e. phase points in a trajectory - ΔH_max::Any # energy in tree with largest absolute different from initial energy +struct BinaryTree{T<:Real,P<:PhasePoint,TS<:TurnStatistic} + zleft::P # left most leaf node + zright::P # right most leaf node + ts::TS # turn statistics + sum_α::T # MH stats, i.e. sum of MH accept prob for all leapfrog steps + nα::Int # total # of leap frog steps, i.e. phase points in a trajectory + ΔH_max::T # energy in tree with largest absolute different from initial energy end """ From b4b58c05105b73a50509927d9da68528c2a95fee Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 23 Feb 2025 00:25:18 +0100 Subject: [PATCH 2/5] Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/constructors.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/constructors.jl b/src/constructors.jl index dbfa1f66..9b3d5ea5 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -144,7 +144,8 @@ For more information, please view the following paper ([arXiv link](https://arxi setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning Research 15, no. 1 (2014): 1593-1623. """ -struct HMCDA{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <: AbstractHMCSampler +struct HMCDA{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <: + AbstractHMCSampler "Target acceptance rate for dual averaging." δ::T "Target leapfrog length." From f269cc86dab4bff3cd04f80ec24012b794b96dae Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 23 Feb 2025 00:38:56 +0100 Subject: [PATCH 3/5] Fix tests --- test/trajectory.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/trajectory.jl b/test/trajectory.jl index 2ff8acbe..8ce7a445 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -72,7 +72,7 @@ ahmc_isturn(h, z0, z1, rho, v = 1) = AdvancedHMC.isterminated( ClassicNoUTurn(), h, - AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0.0, 0, 0.0), ).dynamic function hand_isturn_generalised(z0, z1, rho, v = 1) @@ -84,16 +84,16 @@ ahmc_isturn_generalised(h, z0, z1, rho, v = 1) = AdvancedHMC.isterminated( GeneralisedNoUTurn(), h, - AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0), ).dynamic function ahmc_isturn_strictgeneralised(h, z0, z1, rho, v = 1) t = AdvancedHMC.isterminated( StrictGeneralisedNoUTurn(), h, - AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0), - AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0), - AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0), + AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0), + AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0), ) return t.dynamic end @@ -102,13 +102,13 @@ end Check whether the subtree checks adequately detect U-turns. """ function check_subtree_u_turns(h, z0, z1, rho) - t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0) + t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0) # The left and right subtree are created in such a way that the # check_left_subtree and check_right_subtree checks should be equivalent # to the general no U-turn check. - tleft = AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0) + tleft = AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0) tright = - AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0) + AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0) s1 = AdvancedHMC.isterminated(GeneralisedNoUTurn(), h, t) s2 = AdvancedHMC.check_left_subtree(h, t, tleft, tright) From 2f16c9a74faecd24fe5ba8fa2c4df1436e8d0a76 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 23 Feb 2025 00:44:18 +0100 Subject: [PATCH 4/5] Fix typo --- src/trajectory.jl | 2 +- test/trajectory.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/trajectory.jl b/src/trajectory.jl index 2eefbcc7..14e3e9ae 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -101,7 +101,7 @@ $(TYPEDFIELDS) """ struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler "Sampled candidate `PhasePoint`." - zcand::PhasePoint + zcand::P "Slice variable in log-space." ℓu::F "Number of acceptable candidates, i.e. those with probability larger than slice variable `u`." diff --git a/test/trajectory.jl b/test/trajectory.jl index 8ce7a445..dc71fcf0 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -106,7 +106,8 @@ function check_subtree_u_turns(h, z0, z1, rho) # The left and right subtree are created in such a way that the # check_left_subtree and check_right_subtree checks should be equivalent # to the general no U-turn check. - tleft = AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0) + tleft = + AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0) tright = AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0) From c469e0fc08bcf2dc273ef66a47ef2ce2b8c88fd5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 6 Mar 2025 11:44:56 +0100 Subject: [PATCH 5/5] Update version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index edb687fb..abd6643a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.4" +version = "0.7.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"