-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clarify API for GP approximations #361
Changes from 10 commits
12a0066
401a3dd
6af01e2
73e70d9
3d30d0b
f0e4d10
101e351
e50cd91
7d39319
9de34ce
b3bd78d
dfacf2a
6f14e02
806f633
45c2d6e
501b702
f260171
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ export rand!, | |
mean_vector, | ||
marginals, | ||
logpdf, | ||
approx_log_evidence, | ||
elbo, | ||
dtc, | ||
posterior, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,14 @@ struct VFE{Tfz<:FiniteGP} | |
fz::Tfz | ||
end | ||
|
||
const DTC = VFE | ||
""" | ||
DTC(fz::FiniteGP) | ||
|
||
Similar to `VFE`, but uses a different objective for `approx_log_evidence`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could maybe do with a better docstring but then it needs to be sorted out more thoroughly anyways (see #309) and I can't think of what it should be right now, so would leave that for some other PR/person/time... |
||
""" | ||
struct DTC{Tfz<:FiniteGP} | ||
fz::Tfz | ||
end | ||
|
||
struct ApproxPosteriorGP{Tapprox,Tprior,Tdata} <: AbstractGP | ||
approx::Tapprox | ||
|
@@ -48,7 +55,7 @@ true | |
processes". In: Proceedings of the Twelfth International Conference on Artificial | ||
Intelligence and Statistics. 2009. | ||
""" | ||
function posterior(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
function posterior(vfe::Union{VFE,DTC}, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
@assert vfe.fz.f === fx.f | ||
|
||
U_y = _cholesky(_symmetric(fx.Σy)).U | ||
|
@@ -69,7 +76,7 @@ end | |
|
||
""" | ||
function update_posterior( | ||
f_post_approx::ApproxPosteriorGP{<:VFE}, | ||
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, | ||
fx::FiniteGP, | ||
y::AbstractVector{<:Real} | ||
) | ||
|
@@ -78,7 +85,9 @@ Update the `ApproxPosteriorGP` given a new set of observations. Here, we retain | |
set of pseudo-points. | ||
""" | ||
function update_posterior( | ||
f_post_approx::ApproxPosteriorGP{<:VFE}, fx::FiniteGP, y::AbstractVector{<:Real} | ||
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, | ||
fx::FiniteGP, | ||
y::AbstractVector{<:Real}, | ||
) | ||
@assert f_post_approx.prior === fx.f | ||
|
||
|
@@ -111,14 +120,14 @@ end | |
|
||
""" | ||
function update_posterior( | ||
f_post_approx::ApproxPosteriorGP{<:VFE}, | ||
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, | ||
z::FiniteGP, | ||
) | ||
|
||
Update the `ApproxPosteriorGP` given a new set of pseudo-points to append to the existing | ||
set of pseudo-points. | ||
""" | ||
function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP) | ||
function update_posterior(f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, fz::FiniteGP) | ||
@assert f_post_approx.prior === fz.f | ||
|
||
z_old = inducing_points(f_post_approx) | ||
|
@@ -161,48 +170,56 @@ function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP) | |
x=f_post_approx.data.x, | ||
Σy=f_post_approx.data.Σy, | ||
) | ||
return ApproxPosteriorGP(VFE(fz_new), f_post_approx.prior, cache) | ||
return ApproxPosteriorGP( | ||
_update_approx(f_post_approx.approx, fz_new), f_post_approx.prior, cache | ||
) | ||
end | ||
|
||
_update_approx(vfe::VFE, fz_new::FiniteGP) = VFE(fz_new) | ||
_update_approx(dtc::DTC, fz_new::FiniteGP) = DTC(fz_new) | ||
Comment on lines
+174
to
+179
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this the right way to handle this? if anyone can think of a better approach please say:) |
||
|
||
# AbstractGP interface implementation. | ||
|
||
function Statistics.mean(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector) | ||
function Statistics.mean(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector) | ||
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α | ||
end | ||
|
||
function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector) | ||
function Statistics.cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector) | ||
A = f.data.U' \ cov(f.prior, inducing_points(f), x) | ||
return cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A) | ||
end | ||
|
||
function Statistics.var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector) | ||
function Statistics.var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector) | ||
A = f.data.U' \ cov(f.prior, inducing_points(f), x) | ||
return var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A) | ||
end | ||
|
||
function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector, y::AbstractVector) | ||
function Statistics.cov( | ||
f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector, y::AbstractVector | ||
) | ||
A_zx = f.data.U' \ cov(f.prior, inducing_points(f), x) | ||
A_zy = f.data.U' \ cov(f.prior, inducing_points(f), y) | ||
return cov(f.prior, x, y) - A_zx'A_zy + Xt_invA_Y(A_zx, f.data.Λ_ε, A_zy) | ||
end | ||
|
||
function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector) | ||
function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector) | ||
A = f.data.U' \ cov(f.prior, inducing_points(f), x) | ||
m_post = mean(f.prior, x) + A' * f.data.m_ε | ||
C_post = cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A) | ||
return m_post, C_post | ||
end | ||
|
||
function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector) | ||
function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector) | ||
A = f.data.U' \ cov(f.prior, inducing_points(f), x) | ||
m_post = mean(f.prior, x) + A' * f.data.m_ε | ||
c_post = var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A) | ||
return m_post, c_post | ||
end | ||
|
||
inducing_points(f::ApproxPosteriorGP{<:VFE}) = f.approx.fz.x | ||
inducing_points(f::ApproxPosteriorGP{<:Union{VFE,DTC}}) = f.approx.fz.x | ||
|
||
""" | ||
approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
|
||
The Titsias Evidence Lower BOund (ELBO) [1]. `y` are observations of `fx`, and `v.z` | ||
|
@@ -228,14 +245,16 @@ true | |
processes". In: Proceedings of the Twelfth International Conference on Artificial | ||
Intelligence and Statistics. 2009. | ||
""" | ||
function elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
function approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
@assert vfe.fz.f === fx.f | ||
_dtc, A = _compute_intermediates(fx, y, vfe.fz) | ||
return _dtc - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2 | ||
dtc_objective, A = _compute_intermediates(fx, y, vfe.fz) | ||
return dtc_objective - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2 | ||
end | ||
|
||
elbo(vfe::VFE, fx, y) = approx_log_evidence(vfe, fx, y) | ||
|
||
""" | ||
dtc(v::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
|
||
The Deterministic Training Conditional (DTC) [1]. `y` are observations of `fx`, and `v.z` | ||
are inducing points. | ||
|
@@ -248,25 +267,25 @@ julia> x = randn(1000); | |
|
||
julia> z = range(-5.0, 5.0; length=256); | ||
|
||
julia> v = VFE(f(z)); | ||
julia> d = DTC(f(z)); | ||
|
||
julia> y = rand(f(x, 0.1)); | ||
|
||
julia> isapprox(dtc(v, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6) | ||
julia> isapprox(approx_log_evidence(d, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6) | ||
true | ||
``` | ||
|
||
[1] - M. Seeger, C. K. I. Williams and N. D. Lawrence. "Fast Forward Selection to Speed Up | ||
Sparse Gaussian Process Regression". In: Proceedings of the Ninth International Workshop on | ||
Artificial Intelligence and Statistics. 2003 | ||
""" | ||
function dtc(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
@assert vfe.fz.f === fx.f | ||
_dtc, _ = _compute_intermediates(fx, y, vfe.fz) | ||
return _dtc | ||
function approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real}) | ||
@assert dtc.fz.f === fx.f | ||
dtc_objective, _ = _compute_intermediates(fx, y, dtc.fz) | ||
return dtc_objective | ||
end | ||
|
||
# Factor out computations common to the `elbo` and `dtc`. | ||
# Factor out computations of `approx_log_evidence` common to `VFE` and `DTC` | ||
function _compute_intermediates(fx::FiniteGP, y::AbstractVector{<:Real}, fz::FiniteGP) | ||
length(fx) == length(y) || throw( | ||
DimensionMismatch( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would have gone the other way around?
But maybe that's too complicated...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find this order of arguments quite intuitive.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I meant that the implemented definition would be posterior(::ExactPosterior, gp, y) and that posterior(gp, y) would default to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make any difference in practice? (I kinda like it as it is but that might just be status-quo bias too...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In practice I think the answer is no.
Within the code, there might be something to be said for consistency. Every
posterior
is defined with the 3-argument form, but the exact one gets a special alias.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be fair, there is only one posterior. Shouldn't VFE and DTC dispatched on approx_posterior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@theogf what do you mean? there is no
approx_posterior
method..There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
haha somehow in my mind there was a
approx_posterior
method. So yeah then it's back toExactInference
) ?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ApproximateGPs has the 3-args posterior though, no wrapping?