From fb3f410af6939117a6a1bfa4e102955564082fc2 Mon Sep 17 00:00:00 2001
From: st--
Date: Thu, 17 Mar 2022 14:07:06 +0100
Subject: [PATCH] create ApproximateGPs.TestUtils (#117)

* create ApproximateGPs.TestUtils

* revert reexporting AbstractGPs
---
 Project.toml                                 |   5 +-
 src/ApproximateGPs.jl                        |   3 +
 src/TestUtils.jl                             | 110 +++++++++++++++++++
 test/LaplaceApproximationModule.jl           |  66 +---------
 test/SparseVariationalApproximationModule.jl |   8 +-
 test/runtests.jl                             |  23 ++--
 6 files changed, 139 insertions(+), 76 deletions(-)
 create mode 100644 src/TestUtils.jl

diff --git a/Project.toml b/Project.toml
index ceb4e997..4e550bf2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ApproximateGPs"
 uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
 authors = ["JuliaGaussianProcesses Team"]
-version = "0.3.2"
+version = "0.3.3"
 
 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
@@ -13,11 +13,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
 IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [compat]
 AbstractGPs = "0.3, 0.4, 0.5"
diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl
index c8dd643c..6d5e4d95 100644
--- a/src/ApproximateGPs.jl
+++ b/src/ApproximateGPs.jl
@@ -2,6 +2,7 @@ module ApproximateGPs
 
 using Reexport
+@reexport using AbstractGPs
 @reexport using GPLikelihoods
 
 include("API.jl")
@@ -22,4 +23,6 @@ include("LaplaceApproximationModule.jl")
 
 include("deprecations.jl")
 
+include("TestUtils.jl")
+
 end
diff --git a/src/TestUtils.jl b/src/TestUtils.jl
new file mode 100644
index 00000000..7836e974
--- /dev/null
+++ b/src/TestUtils.jl
@@ -0,0 +1,110 @@
+module TestUtils
+
+using LinearAlgebra
+using Random
+using Test
+
+using Distributions
+using LogExpFunctions: logistic, softplus
+
+using AbstractGPs
+using ApproximateGPs
+
+function generate_data()
+    X = range(0, 23.5; length=48)
+    # The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
+    # The generating code below is only kept for illustrative purposes.
+    #! format: off
+    Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
+    #! format: on
+    # Random.seed!(1)
+    # fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
+    # # invlink = normcdf
+    # invlink = logistic
+    # ps = invlink.(fs)
+    # Y = @. rand(Bernoulli(ps))
+    return X, Y
+end
+
+dist_y_given_f(f) = Bernoulli(logistic(f))
+
+function build_latent_gp(theta)
+    variance = softplus(theta[1])
+    lengthscale = softplus(theta[2])
+    kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
+    return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
+end
+
+"""
+    test_approx_lml(approx)
+
+Test whether in the conjugate case `approx_lml(approx, LatentGP(f,
+GaussianLikelihood(), jitter)(x), y)` gives approximately the same answer as
+the log marginal likelihood in exact GP regression.
+
+!!! todo
+    Not yet implemented.
+
+    Will not necessarily work for approximations that rely on optimization such
+    as `SparseVariationalApproximation`.
+
+!!! todo
+    Also test gradients (for hyperparameter optimization).
+"""
+function test_approx_lml end
+
+"""
+    test_approximation_predictions(approx)
+
+Test whether the prediction interface for `approx` works and whether in the
+conjugate case `posterior(approx, LatentGP(f, GaussianLikelihood(), jitter)(x), y)`
+gives approximately the same answer as the exact GP regression posterior.
+
+!!! note
+    Should be satisfied by all approximate inference methods, but note that
+    this does not currently apply for some approximations which rely on
+    optimization such as `SparseVariationalApproximation`.
+
+!!! warning
+    Do not rely on this as the only test of a new approximation!
+
+See `test_approx_lml`.
+"""
+function test_approximation_predictions(approx)
+    rng = MersenneTwister(123456)
+    N_cond = 5
+    N_a = 6
+    N_b = 7
+
+    # Specify prior.
+    f = GP(Matern32Kernel())
+    # Sample from prior.
+    x = collect(range(-1.0, 1.0; length=N_cond))
+    # TODO: Change to x = ColVecs(rand(2, N_cond)) once #109 is fixed
+    noise_scale = 0.1
+    fx = f(x, noise_scale^2)
+    y = rand(rng, fx)
+
+    jitter = 0.0 # not needed in Gaussian case
+    lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
+    f_approx_post = posterior(approx, lf(x), y)
+
+    @testset "AbstractGPs API" begin
+        a = collect(range(-1.2, 1.2; length=N_a))
+        b = randn(rng, N_b)
+        AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f_approx_post, a, b)
+    end
+
+    @testset "exact GPR equivalence for Gaussian likelihood" begin
+        f_exact_post = posterior(f(x, noise_scale^2), y)
+        xt = vcat(x, randn(rng, 3)) # test at training and new points
+
+        m_approx, c_approx = mean_and_cov(f_approx_post(xt))
+        m_exact, c_exact = mean_and_cov(f_exact_post(xt))
+
+        @test m_approx ≈ m_exact
+        @test c_approx ≈ c_exact
+    end
+end
+
+end
diff --git a/test/LaplaceApproximationModule.jl b/test/LaplaceApproximationModule.jl
index 68548648..a3f705c2 100644
--- a/test/LaplaceApproximationModule.jl
+++ b/test/LaplaceApproximationModule.jl
@@ -1,28 +1,7 @@
 @testset "laplace" begin
-    function generate_data()
-        X = range(0, 23.5; length=48)
-        # The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
-        # The generating code below is only kept for illustrative purposes.
-        #! format: off
-        Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
-        #! format: on
-        # Random.seed!(1)
-        # fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
-        # # invlink = normcdf
-        # invlink = logistic
-        # ps = invlink.(fs)
-        # Y = [rand(Bernoulli(p)) for p in ps]
-        return X, Y
-    end
-
-    dist_y_given_f(f) = Bernoulli(logistic(f))
-
-    function build_latent_gp(theta)
-        variance = softplus(theta[1])
-        lengthscale = softplus(theta[2])
-        kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
-        return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
-    end
+    generate_data = ApproximateGPs.TestUtils.generate_data
+    dist_y_given_f = ApproximateGPs.TestUtils.dist_y_given_f
+    build_latent_gp = ApproximateGPs.TestUtils.build_latent_gp
 
     function optimize_elbo(
         build_latent_gp,
@@ -49,43 +28,8 @@
     end
 
     @testset "predictions" begin
-        rng = MersenneTwister(123456)
-        N_cond = 5
-        N_a = 6
-        N_b = 7
-
-        # Specify prior.
-        f = GP(Matern32Kernel())
-        # Sample from prior.
-        x = collect(range(-1.0, 1.0; length=N_cond))
-        noise_scale = 0.1
-        fx = f(x, noise_scale^2)
-        y = rand(rng, fx)
-
-        jitter = 0.0 # not needed in Gaussian case
-        lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
-        # in Gaussian case, Laplace converges to f_opt in one step; we need the
-        # second step to compute the cache at f_opt rather than f_init!
-        f_approx_post = posterior(LaplaceApproximation(; maxiter=2), lf(x), y)
-
-        @testset "AbstractGPs API" begin
-            a = collect(range(-1.2, 1.2; length=N_a))
-            b = randn(rng, N_b)
-            AbstractGPs.TestUtils.test_internal_abstractgps_interface(
-                rng, f_approx_post, a, b
-            )
-        end
-
-        @testset "equivalence to exact GPR for Gaussian likelihood" begin
-            f_exact_post = posterior(f(x, noise_scale^2), y)
-            xt = vcat(x, randn(rng, 3)) # test at training and new points
-
-            m_approx, c_approx = mean_and_cov(f_approx_post(xt))
-            m_exact, c_exact = mean_and_cov(f_exact_post(xt))
-
-            @test m_approx ≈ m_exact
-            @test c_approx ≈ c_exact
-        end
+        approx = LaplaceApproximation(; maxiter=2)
+        ApproximateGPs.TestUtils.test_approximation_predictions(approx)
     end
 
     @testset "gradients" begin
diff --git a/test/SparseVariationalApproximationModule.jl b/test/SparseVariationalApproximationModule.jl
index 104949f2..3deb7f34 100644
--- a/test/SparseVariationalApproximationModule.jl
+++ b/test/SparseVariationalApproximationModule.jl
@@ -28,7 +28,9 @@
     b = randn(rng, N_b)
 
    @testset "AbstractGPs interface - Centered" begin
-        TestUtils.test_internal_abstractgps_interface(rng, f_approx_post_Centered, a, b)
+        AbstractGPs.TestUtils.test_internal_abstractgps_interface(
+            rng, f_approx_post_Centered, a, b
+        )
     end
 
     @testset "NonCentered" begin
@@ -50,7 +52,7 @@
         f_approx_post_non_Centered = posterior(approx_non_Centered)
 
         @testset "AbstractGPs interface - NonCentered" begin
-            TestUtils.test_internal_abstractgps_interface(
+            AbstractGPs.TestUtils.test_internal_abstractgps_interface(
                 rng, f_approx_post_non_Centered, a, b
             )
         end
@@ -170,7 +172,7 @@
         # Train the SVGP model
         data = [(x, y)]
 
-        opt = ADAM(0.001)
+        opt = Flux.ADAM(0.001)
 
         svgp_ps = Flux.params(svgp_model)
 
diff --git a/test/runtests.jl b/test/runtests.jl
index fb7f5ae8..a1ec05f4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,19 +1,20 @@
+using LinearAlgebra
 using Random
 using Test
-using ApproximateGPs
-using Flux
-using IterTools
-using AbstractGPs
-using AbstractGPs: LatentFiniteGP, TestUtils
-using Distributions
-using LogExpFunctions: logistic
-using LinearAlgebra
-using PDMats
-using Optim
-using Zygote
+
 using ChainRulesCore
 using ChainRulesTestUtils
+using Distributions
 using FiniteDifferences
+using Flux: Flux
+using IterTools
+using LogExpFunctions: softplus
+using Optim
+using PDMats
+using Zygote
+
+using AbstractGPs
+using ApproximateGPs
 using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
 
 # Writing tests:
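A minimal sketch of how a downstream test suite could call the new entry point, mirroring the call added to test/LaplaceApproximationModule.jl above; the test-set name is hypothetical, and `maxiter=2` follows the comment removed from that file (in the Gaussian case the Laplace approximation converges in one step, and the second iteration recomputes the cache at the optimum rather than at the initial point):

    using Test
    using ApproximateGPs

    @testset "my new approximation" begin
        # Any approximation accepted by `posterior(approx, lfx, y)` can be
        # passed; LaplaceApproximation is reused here as in the patch above.
        approx = LaplaceApproximation(; maxiter=2)

        # Runs the AbstractGPs interface checks and compares mean_and_cov of
        # the approximate posterior against exact GP regression under a
        # Gaussian likelihood.
        ApproximateGPs.TestUtils.test_approximation_predictions(approx)
    end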