Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create ApproximateGPs.TestUtils #117

Merged
merged 17 commits into from
Mar 17, 2022
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.3.2"
version = "0.3.3"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand All @@ -13,11 +13,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractGPs = "0.3, 0.4, 0.5"
Expand Down
3 changes: 3 additions & 0 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ApproximateGPs

using Reexport

@reexport using AbstractGPs
@reexport using GPLikelihoods

include("API.jl")
Expand All @@ -22,4 +23,6 @@ include("LaplaceApproximationModule.jl")

include("deprecations.jl")

include("TestUtils.jl")

end
110 changes: 110 additions & 0 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
module TestUtils

using LinearAlgebra
using Random
using Test

using Distributions
using LogExpFunctions: logistic, softplus

using AbstractGPs
using ApproximateGPs

function generate_data()
X = range(0, 23.5; length=48)
# The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
# The generating code below is only kept for illustrative purposes.
#! format: off
Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
#! format: on
# Random.seed!(1)
# fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
# # invlink = normcdf
# invlink = logistic
# ps = invlink.(fs)
# Y = @. rand(Bernoulli(ps))
return X, Y
end

dist_y_given_f(f) = Bernoulli(logistic(f))

function build_latent_gp(theta)
variance = softplus(theta[1])
lengthscale = softplus(theta[2])
kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
end

"""
test_approx_lml(approx)

Test whether in the conjugate case `approx_lml(approx, LatentGP(f,
GaussianLikelihood(), jitter)(x), y)` gives approximately the same answer as
the log marginal likelihood in exact GP regression.

!!! todo
Not yet implemented.

Will not necessarily work for approximations that rely on optimization such
as `SparseVariationalApproximation`.

!!! todo
Also test gradients (for hyperparameter optimization).
"""
function test_approx_lml end

"""
test_approximation_predictions(approx)

Test whether the prediction interface for `approx` works and whether in the
conjugate case `posterior(approx, LatentGP(f, GaussianLikelihood(), jitter)(x), y)`
gives approximately the same answer as the exact GP regression posterior.

!!! note
Should be satisfied by all approximate inference methods, but note that
this does not currently apply for some approximations which rely on
optimization such as `SparseVariationalApproximation`.

!!! warning
Do not rely on this as the only test of a new approximation!

See `test_approx_lml`.
"""
function test_approximation_predictions(approx)
rng = MersenneTwister(123456)
N_cond = 5
N_a = 6
N_b = 7

# Specify prior.
f = GP(Matern32Kernel())
# Sample from prior.
x = collect(range(-1.0, 1.0; length=N_cond))
# TODO: Change to x = ColVecs(rand(2, N_cond)) once #109 is fixed
noise_scale = 0.1
fx = f(x, noise_scale^2)
y = rand(rng, fx)

jitter = 0.0 # not needed in Gaussian case
lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
f_approx_post = posterior(approx, lf(x), y)

@testset "AbstractGPs API" begin
a = collect(range(-1.2, 1.2; length=N_a))
b = randn(rng, N_b)
AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f_approx_post, a, b)
end

@testset "exact GPR equivalence for Gaussian likelihood" begin
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
f_exact_post = posterior(f(x, noise_scale^2), y)
xt = vcat(x, randn(rng, 3)) # test at training and new points

m_approx, c_approx = mean_and_cov(f_approx_post(xt))
m_exact, c_exact = mean_and_cov(f_exact_post(xt))

@test m_approx ≈ m_exact
@test c_approx ≈ c_exact
end
end

end
66 changes: 5 additions & 61 deletions test/LaplaceApproximationModule.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,7 @@
@testset "laplace" begin
function generate_data()
X = range(0, 23.5; length=48)
# The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
# The generating code below is only kept for illustrative purposes.
#! format: off
Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
#! format: on
# Random.seed!(1)
# fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
# # invlink = normcdf
# invlink = logistic
# ps = invlink.(fs)
# Y = [rand(Bernoulli(p)) for p in ps]
return X, Y
end

dist_y_given_f(f) = Bernoulli(logistic(f))

function build_latent_gp(theta)
variance = softplus(theta[1])
lengthscale = softplus(theta[2])
kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
end
generate_data = ApproximateGPs.TestUtils.generate_data
dist_y_given_f = ApproximateGPs.TestUtils.dist_y_given_f
build_latent_gp = ApproximateGPs.TestUtils.build_latent_gp

function optimize_elbo(
build_latent_gp,
Expand All @@ -49,43 +28,8 @@
end

@testset "predictions" begin
rng = MersenneTwister(123456)
N_cond = 5
N_a = 6
N_b = 7

# Specify prior.
f = GP(Matern32Kernel())
# Sample from prior.
x = collect(range(-1.0, 1.0; length=N_cond))
noise_scale = 0.1
fx = f(x, noise_scale^2)
y = rand(rng, fx)

jitter = 0.0 # not needed in Gaussian case
lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
# in Gaussian case, Laplace converges to f_opt in one step; we need the
# second step to compute the cache at f_opt rather than f_init!
f_approx_post = posterior(LaplaceApproximation(; maxiter=2), lf(x), y)

@testset "AbstractGPs API" begin
a = collect(range(-1.2, 1.2; length=N_a))
b = randn(rng, N_b)
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
rng, f_approx_post, a, b
)
end

@testset "equivalence to exact GPR for Gaussian likelihood" begin
f_exact_post = posterior(f(x, noise_scale^2), y)
xt = vcat(x, randn(rng, 3)) # test at training and new points

m_approx, c_approx = mean_and_cov(f_approx_post(xt))
m_exact, c_exact = mean_and_cov(f_exact_post(xt))

@test m_approx ≈ m_exact
@test c_approx ≈ c_exact
end
approx = LaplaceApproximation(; maxiter=2)
ApproximateGPs.TestUtils.test_approximation_predictions(approx)
end

@testset "gradients" begin
Expand Down
8 changes: 5 additions & 3 deletions test/SparseVariationalApproximationModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
b = randn(rng, N_b)

@testset "AbstractGPs interface - Centered" begin
TestUtils.test_internal_abstractgps_interface(rng, f_approx_post_Centered, a, b)
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
rng, f_approx_post_Centered, a, b
)
end

@testset "NonCentered" begin
Expand All @@ -50,7 +52,7 @@
f_approx_post_non_Centered = posterior(approx_non_Centered)

@testset "AbstractGPs interface - NonCentered" begin
TestUtils.test_internal_abstractgps_interface(
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
rng, f_approx_post_non_Centered, a, b
)
end
Expand Down Expand Up @@ -170,7 +172,7 @@

# Train the SVGP model
data = [(x, y)]
opt = ADAM(0.001)
opt = Flux.ADAM(0.001)

svgp_ps = Flux.params(svgp_model)

Expand Down
23 changes: 12 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
using LinearAlgebra
using Random
using Test
using ApproximateGPs
using Flux
using IterTools
using AbstractGPs
using AbstractGPs: LatentFiniteGP, TestUtils
using Distributions
using LogExpFunctions: logistic
using LinearAlgebra
using PDMats
using Optim
using Zygote

using ChainRulesCore
using ChainRulesTestUtils
using Distributions
using FiniteDifferences
using Flux: Flux
using IterTools
using LogExpFunctions: softplus
using Optim
using PDMats
using Zygote

using AbstractGPs
using ApproximateGPs
using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule

# Writing tests:
Expand Down