Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fields with abstract types #399

Merged
merged 7 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.6.4"
version = "0.7.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
4 changes: 2 additions & 2 deletions src/adaptation/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ References
Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: adaptively setting path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593-1623.
Nesterov, Y. (2009). Primal-dual subgradient methods for convex problems. Mathematical programming, 120(1), 221-259.
"""
struct NesterovDualAveraging{T<:AbstractFloat} <: StepSizeAdaptor
struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: StepSizeAdaptor
γ::T
t_0::T
κ::T
δ::T
state::DAState{<:AbstractScalarOrVec{T}}
state::DAState{S}
end
Base.show(io::IO, a::NesterovDualAveraging) = print(
io,
Expand Down
7 changes: 4 additions & 3 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,16 @@ For more information, please view the following paper ([arXiv link](https://arxi
setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning
Research 15, no. 1 (2014): 1593-1623.
"""
struct HMCDA{T<:Real} <: AbstractHMCSampler
struct HMCDA{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
AbstractHMCSampler
"Target acceptance rate for dual averaging."
δ::T
"Target leapfrog length."
λ::T
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::Union{Symbol,AbstractIntegrator}
integrator::I
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::Union{Symbol,AbstractMetric}
metric::M
end

function HMCDA(δ, λ; integrator = :leapfrog, metric = :diagonal)
Expand Down
22 changes: 11 additions & 11 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ It contains the slice variable and the number of acceptable condidates in the tr

$(TYPEDFIELDS)
"""
struct SliceTS{F<:AbstractFloat} <: AbstractTrajectorySampler
struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
"Sampled candidate `PhasePoint`."
zcand::PhasePoint
zcand::P
"Slice variable in log-space."
ℓu::F
"Number of acceptable candidates, i.e. those with probability larger than slice variable `u`."
Expand All @@ -120,9 +120,9 @@ It contains the weight of the tree, defined as the total probabilities of the le

$(TYPEDFIELDS)
"""
struct MultinomialTS{F<:AbstractFloat} <: AbstractTrajectorySampler
struct MultinomialTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
"Sampled candidate `PhasePoint`."
zcand::PhasePoint
zcand::P
"Total energy for the given tree, i.e. the sum of energies of all leaves."
ℓw::F
end
Expand Down Expand Up @@ -499,13 +499,13 @@ end
"""
A full binary tree trajectory with only necessary leaves and information stored.
"""
struct BinaryTree
zleft::Any # left most leaf node
zright::Any # right most leaf node
ts::Any # turn statistics
sum_α::Any # MH stats, i.e. sum of MH accept prob for all leapfrog steps
nα::Any # total # of leap frog steps, i.e. phase points in a trajectory
ΔH_max::Any # energy in tree with largest absolute different from initial energy
struct BinaryTree{T<:Real,P<:PhasePoint,TS<:TurnStatistic}
zleft::P # left most leaf node
zright::P # right most leaf node
ts::TS # turn statistics
sum_α::T # MH stats, i.e. sum of MH accept prob for all leapfrog steps
nα::Int # total # of leap frog steps, i.e. phase points in a trajectory
ΔH_max::T # energy in tree with largest absolute different from initial energy
end

"""
Expand Down
17 changes: 9 additions & 8 deletions test/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ ahmc_isturn(h, z0, z1, rho, v = 1) =
AdvancedHMC.isterminated(
ClassicNoUTurn(),
h,
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0, 0, 0.0),
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(), 0.0, 0, 0.0),
).dynamic

function hand_isturn_generalised(z0, z1, rho, v = 1)
Expand All @@ -84,16 +84,16 @@ ahmc_isturn_generalised(h, z0, z1, rho, v = 1) =
AdvancedHMC.isterminated(
GeneralisedNoUTurn(),
h,
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0),
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0),
).dynamic

function ahmc_isturn_strictgeneralised(h, z0, z1, rho, v = 1)
t = AdvancedHMC.isterminated(
StrictGeneralisedNoUTurn(),
h,
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0),
AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0),
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0),
AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0),
AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0),
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0),
)
return t.dynamic
end
Expand All @@ -102,13 +102,14 @@ end
Check whether the subtree checks adequately detect U-turns.
"""
function check_subtree_u_turns(h, z0, z1, rho)
t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0, 0, 0.0)
t = AdvancedHMC.BinaryTree(z0, z1, AdvancedHMC.TurnStatistic(rho), 0.0, 0, 0.0)
# The left and right subtree are created in such a way that the
# check_left_subtree and check_right_subtree checks should be equivalent
# to the general no U-turn check.
tleft = AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0, 0, 0.0)
tleft =
AdvancedHMC.BinaryTree(z0, z0, AdvancedHMC.TurnStatistic(rho - z1.r), 0.0, 0, 0.0)
tright =
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0, 0, 0.0)
AdvancedHMC.BinaryTree(z1, z1, AdvancedHMC.TurnStatistic(rho - z0.r), 0.0, 0, 0.0)

s1 = AdvancedHMC.isterminated(GeneralisedNoUTurn(), h, t)
s2 = AdvancedHMC.check_left_subtree(h, t, tleft, tright)
Expand Down
Loading