diff --git a/Project.toml b/Project.toml
index 9bb9ac1..cc9f850 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,8 @@ version = "0.5.6"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/src/AdvancedMH.jl b/src/AdvancedMH.jl
index 0c0e64b..4ad7084 100644
--- a/src/AdvancedMH.jl
+++ b/src/AdvancedMH.jl
@@ -4,12 +4,15 @@ module AdvancedMH
 using AbstractMCMC
 using Distributions
 using Requires
+using LinearAlgebra
+using PDMats
 
 import Random
 
 # Exports
 export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal,
-    RandomWalkProposal, Ensemble, StretchProposal, MALA
+    RandomWalkProposal, Ensemble, StretchProposal, MALA, AdaptiveProposal,
+    AdaptiveMvNormal
 
 # Reexports
 export sample, MCMCThreads, MCMCDistributed
@@ -110,5 +113,7 @@ end
 include("proposal.jl")
 include("mh-core.jl")
 include("emcee.jl")
+include("adaptive.jl")
+include("adaptivemvnormal.jl")
 
 end # module AdvancedMH
diff --git a/src/adaptive.jl b/src/adaptive.jl
new file mode 100644
index 0000000..536e208
--- /dev/null
+++ b/src/adaptive.jl
@@ -0,0 +1,110 @@
+"""
+    Adaptor(; tune=25, target=0.44, bound=10., δmax=0.2)
+
+A helper struct for univariate adaptive proposal kernels. It tracks the number
+of accepted proposals and the total number of attempted proposals. The
+proposal kernel is tuned every `tune` proposals, such that the scale (`log(σ)`
+in the case of a Normal kernel, `log(b)` for a Uniform kernel) of the proposal
+is increased (decreased) by `δ(n) = min(δmax, 1/√n)` at tuning step `n` if the
+estimated acceptance probability is higher (lower) than `target`. The target
+acceptance probability defaults to 0.44, which is often cited as optimal for
+one-dimensional proposals. To ensure ergodicity, the scale of the proposal has
+to be bounded (by `bound`), although in practice this constraint rarely binds.
+"""
+mutable struct Adaptor
+    accepted::Int
+    total::Int
+    tune::Int       # tuning interval
+    target::Float64 # target acceptance rate
+    bound::Float64  # bound on logσ of Gaussian kernel
+    δmax::Float64   # maximum adaptation step
+end
+
+function Adaptor(; tune=25, target=0.44, bound=10., δmax=0.2)
+    return Adaptor(0, 0, tune, target, bound, δmax)
+end
+
+"""
+    AdaptiveProposal{P}
+
+An adaptive Metropolis-Hastings proposal. To work with a given proposal
+kernel, the kernel has to implement the `adapted(proposal, δ)` method, where
+`δ` is the increment/decrement applied to the scale of the proposal
+distribution during adaptation (e.g. for a Normal distribution the scale is
+`log(σ)`, so that after adaptation the proposal is
+`Normal(μ, exp(log(σ) + δ))`).
+
+# Example
+```julia
+julia> p = AdaptiveProposal(Uniform(-0.2, 0.2));
+
+julia> rand(p)
+0.07975590594518434
+```
+
+# References
+
+Roberts, Gareth O., and Jeffrey S. Rosenthal. "Examples of adaptive MCMC."
+Journal of Computational and Graphical Statistics 18.2 (2009): 349-367.
+"""
+mutable struct AdaptiveProposal{P} <: Proposal{P}
+    proposal::P
+    adaptor::Adaptor
+end
+
+function AdaptiveProposal(p; kwargs...)
+    return AdaptiveProposal(p, Adaptor(; kwargs...))
+end
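+
+# Illustrative usage (a sketch along the lines of the test suite; `model` is
+# any `DensityModel`):
+#
+#   spl = MetropolisHastings(AdaptiveProposal(Normal(0, 1)))
+#   chain = sample(model, spl, 100000; chain_type=Chains)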
+
+# Adaptive proposals are only defined for symmetric proposal distributions
+is_symmetric_proposal(::AdaptiveProposal) = true
+
+# non-adaptive proposals fall back to the no-op defined in mh-core.jl
+accepted!(p::AdaptiveProposal) = p.adaptor.accepted += 1
+accepted!(p::AbstractVector{<:Proposal}) = map(accepted!, p)
+accepted!(p::NamedTuple{names}) where names = map(x -> accepted!(getfield(p, x)), names)
+
+# the first draw is unconditional: there is no previous transition to condition on
+function propose(rng::Random.AbstractRNG, p::AdaptiveProposal, m::DensityModel)
+    return rand(rng, p.proposal)
+end
+
+# the actual proposal happens here
+function propose(
+    rng::Random.AbstractRNG,
+    proposal::AdaptiveProposal{<:Union{Distribution,Proposal}},
+    model::DensityModel,
+    t
+)
+    consider_adaptation!(proposal)
+    return t + rand(rng, proposal.proposal)
+end
+
+function q(proposal::AdaptiveProposal, t, t_cond)
+    return logpdf(proposal, t - t_cond)
+end
+
+function consider_adaptation!(p)
+    (p.adaptor.total % p.adaptor.tune == 0) && adapt!(p)
+    p.adaptor.total += 1
+end
+
+function adapt!(p::AdaptiveProposal)
+    a = p.adaptor
+    a.total == 0 && return
+    δ = min(a.δmax, sqrt(a.tune / a.total))  # diminishing adaptation step
+    α = a.accepted / a.tune                  # acceptance rate over the last window
+    p_ = adapted(p.proposal, α > a.target ? δ : -δ, a.bound)
+    a.accepted = 0
+    p.proposal = p_
+end
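+
+# With the defaults (tune=25, δmax=0.2), the step at tuning round n is
+# δ(n) = min(0.2, 1/√n): the cap δmax is binding up to n = 25 (the first 625
+# proposals), after which the step decays as 1/√n (e.g. δ = 0.1 at n = 100),
+# giving the diminishing adaptation required for ergodicity.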
+
+function adapted(d::Normal, δ, bound=Inf)
+    _lσ = log(d.σ) + δ
+    lσ = sign(_lσ) * min(bound, abs(_lσ))
+    return Normal(d.μ, exp(lσ))
+end
+
+function adapted(d::Uniform, δ, bound=Inf)
+    lσ = log(d.b) + δ
+    σ = exp(sign(lσ) * min(bound, abs(lσ)))
+    return Uniform(-σ, σ)
+end
diff --git a/src/adaptivemvnormal.jl b/src/adaptivemvnormal.jl
new file mode 100644
index 0000000..de7734d
--- /dev/null
+++ b/src/adaptivemvnormal.jl
@@ -0,0 +1,82 @@
+"""
+    AdaptiveMvNormal(constant_component::MvNormal; σ=2.38, β=0.05)
+
+Adaptive multivariate normal mixture proposal as described in Haario et al.
+(2001) and Roberts & Rosenthal (2009). Uses a two-component mixture of
+`MvNormal` distributions: one component (with mixture weight `β`) remains
+constant, while the other is adapted to the covariance structure of the
+target. The proposal is initialized by providing the constant component to
+the constructor.
+
+`σ` is the scale factor for the adapted covariance matrix; the default 2.38
+is the asymptotically optimal choice in high-dimensional settings according
+to Roberts & Rosenthal.
+
+# References
+
+- Haario, Heikki, Eero Saksman, and Johanna Tamminen.
+  "An adaptive Metropolis algorithm." Bernoulli 7.2 (2001): 223-242.
+- Roberts, Gareth O., and Jeffrey S. Rosenthal. "Examples of adaptive MCMC."
+  Journal of Computational and Graphical Statistics 18.2 (2009): 349-367.
+"""
+mutable struct AdaptiveMvNormal{T1,T2,V} <: Proposal{T1}
+    d::Int          # dimensionality
+    n::Int          # iteration
+    β::Float64      # constant component mixture weight
+    σ::Float64      # scale factor for adapted covariance matrix
+    constant::T1
+    adaptive::T2
+    Ex::Vector{V}   # rolling mean vector
+    EX::Matrix{V}   # scatter matrix of previous draws
+end
+
+function AdaptiveMvNormal(dist::MvNormal; σ=2.38, β=0.05)
+    n = length(dist)
+    adaptive = MvNormal(cov(dist))
+    AdaptiveMvNormal(n, -1, β, σ, dist, adaptive, zeros(n), zeros(n, n))
+end
+
+is_symmetric_proposal(::AdaptiveMvNormal) = true
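+
+# Once n > 2d, proposal increments follow the two-component mixture of
+# Roberts & Rosenthal (2009),
+#
+#   Q(x, ⋅) = (1 - β) N(x, σ²Σ̂/d) + β N(x, Σ₀),
+#
+# where Σ̂ is the running estimate of the target covariance (see `adapt!`
+# below) and Σ₀ is the covariance of the constant component supplied to the
+# constructor; before that, only the constant component is used.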
+
+"""
+    adapt!(p::AdaptiveMvNormal, x::AbstractVector)
+
+Adaptation for the adaptive multivariate normal mixture proposal as described
+in Haario et al. (2001) and Roberts & Rosenthal (2009). Performs an online
+estimation of the target covariance matrix and mean. The code for this routine
+is largely based on `Mamba.jl`.
+"""
+function adapt!(p::AdaptiveMvNormal, x::AbstractVector)
+    p.n += 1
+    # update the rolling mean vector and scatter matrix
+    f = p.n / (p.n + 1)
+    p.Ex = f * p.Ex + (1 - f) * x
+    p.EX = f * p.EX + (1 - f) * x * x'
+    # compute the adapted covariance matrix
+    Σ = (p.σ^2 / (p.d * f)) * (p.EX - p.Ex * p.Ex')
+    # only adopt the new covariance if it is positive definite (full-rank
+    # Cholesky factor); otherwise keep the previous adaptive component
+    F = cholesky(Hermitian(Σ), check=false)
+    if rank(F.L) == p.d
+        p.adaptive = MvNormal(PDMat(Σ, F))
+    end
+end
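+
+# After m = p.n + 1 draws, `Ex` is the exact running mean and `EX` the running
+# mean of x*x', so `EX - Ex*Ex'` is the (1/m)-normalized sample covariance;
+# the 1/f = m/(m-1) factor applies Bessel's correction before scaling by σ²/d.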
@test mean(chain2["μ"]) ≈ 0.0 atol=0.1 @test mean(chain2["σ"]) ≈ 1.0 atol=0.1 end @@ -84,10 +141,11 @@ include("util.jl") p3 = (a=StaticProposal(Normal(0,1)), b=StaticProposal(InverseGamma(2,3))) p4 = StaticProposal((x=1.0) -> Normal(x, 1)) - c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple}) - c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple}) - c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple}) - c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple}) + kwargs = (chain_type=Vector{NamedTuple}, progress=false) + c1 = sample(m1, MetropolisHastings(p1), 100; kwargs...) + c2 = sample(m2, MetropolisHastings(p2), 100; kwargs...) + c3 = sample(m3, MetropolisHastings(p3), 100; kwargs...) + c4 = sample(m4, MetropolisHastings(p4), 100; kwargs...) @test keys(c1[1]) == (:param_1, :lp) @test keys(c2[1]) == (:param_1, :param_2, :lp) @@ -102,7 +160,7 @@ include("util.jl") val = [0.4, 1.2] # Sample from the posterior. - chain1 = sample(model, spl1, 10, init_params = val) + chain1 = sample(model, spl1, 10, init_params = val, progress=false) @test chain1[1].params == val end @@ -124,21 +182,22 @@ include("util.jl") @test AdvancedMH.is_symmetric_proposal(p1) # Sample from the posterior with initial parameters. - chain1 = sample(m1, MetropolisHastings(p1), 100000; - chain_type=StructArray, param_names=["x"]) + chain1 = sample(m1, MetropolisHastings(p1), 100000; + progress=false, chain_type=StructArray, param_names=["x"]) @test mean(chain1.x) ≈ mean(d1) atol=0.05 @test std(chain1.x) ≈ std(d1) atol=0.05 end @testset "MALA" begin - + # Set up the sampler. sigma = 1e-1 spl1 = MALA(x -> MvNormal((sigma^2 / 2) .* x, sigma)) # Sample from the posterior with initial parameters. - chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"]) + chain1 = sample(model, spl1, 100000; progress=false, init_params=ones(2), + chain_type=StructArray, param_names=["μ", "σ"]) @test mean(chain1.μ) ≈ 0.0 atol=0.1 @test mean(chain1.σ) ≈ 1.0 atol=0.1