Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
*.jl.mem
/docs/build/
.vscode
Manifest.toml
test/Manifest.toml
Manifest.toml
70 changes: 68 additions & 2 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ArraysOfArrays
using Distributions
using GPLikelihoods: AbstractLikelihood
using LinearAlgebra
using LogExpFunctions: logsumexp
using MeasureBase: MeasureBase, logdensity_def, marginals
using MeasureTheory: For
using SplitApplyCombine: invert
Expand All @@ -20,7 +21,6 @@ remove_ntdist_wrapper(d) = d

can_split(::AbstractLikelihood) = true


function flatten_params(φ, n_l)
if n_l == 1
return vcat(values(φ)...)
Expand Down Expand Up @@ -65,6 +65,22 @@ function test_auglik(
ft = invert_if_array_of_arrays(f)
qft = invert_if_array_of_arrays(qf)
nf = nlatent(lik)
@testset "Marginalizing resolve to original pdf" begin
# We draw S samples from the auxiliary variable and checks that the marginal
# resolve to the true pdf.
orig_lik = lik(f)
orig_logpdf = logpdf(orig_lik, y)
S = 1_000_000
aux_dist = aux_prior(lik, y)
aug_logpdf = logsumexp(
map(1:S) do Ω
Ω = tvrand(rng, aux_dist)
logtilt(lik, Ω, y, f) - log(S)
end,
)
@test orig_logpdf ≈ aug_logpdf atol = 1e-1 # This is high cause we estimate the thing
# not in log-space
end
# Testing sampling
@testset "Sampling" begin
Ω = init_aux_variables(lik, n)
Expand Down Expand Up @@ -130,7 +146,12 @@ function test_auglik(
logC₂ = logtilt(lik, Ω, y, f₂) + logpdf(pF, f₂) - logpdf(qF, f₂)
@test logC₁ ≈ logC₂ atol = 1e-5
else
S = inv.(Symmetric.(Ref(inv(K)) .+ Diagonal.(auglik_precision(lik, Ω, y, ft))))
S =
inv.(
Symmetric.(
Ref(inv(K)) .+ Diagonal.(auglik_precision(lik, Ω, y, ft))
)
)
m = S .* auglik_potential(lik, Ω, y, ft)
qF = MvNormal.(m, S)
pF = MvNormal(K)
Expand Down Expand Up @@ -170,6 +191,51 @@ function test_auglik(

@test all(x -> all(>=(0), x), γs) # Check that the variance is positive

@testset "expected_logtilt" begin
@test expected_logtilt(lik, qΩ, y, qf) isa Real
S = 1_000_000
val = expected_logtilt(lik, qΩ, y, qf)
samp_val =
mapreduce(
+, (tvrand(rng, qΩ) for _ in 1:S), (rand.(rng, qf) for _ in 1:S)
) do Ω, f
logtilt(lik, Ω, y, f)
end / S
@test val ≈ samp_val atol = 1e-2 # This is still pretty high
end

# Check that the obtained auxiliary posterior is indeed a maximum.
@testset "aux_posterior" begin
φ = TupleVectors.unwrap(aux_posterior(lik, y, qf).pars) # TupleVector
φ_opt = vcat(values(φ)...)
s = keys(φ)
n_var = length(s)
function loss(φ)
q = For(
qΩ.f,
TupleVector(
NamedTuple{s}(
collect(φ[((j - 1) * n_var + 1):(j * n_var)] for j in 1:n_var)
),
),
)
return -expected_logtilt(lik, q, y, qf) + aux_kldivergence(lik, q, y)
end
ϵ = 1e-2
# Test that by perturbing the value in random directions, the loss does not decrease
for i in n_var * n
(lik isa PoissonLikelihood && i <= n) && continue # We do not want to vary y and Integer parameters.
Δ = zeros(n_var * n)
Δ[i] = ϵ # We try one element at a time
@test loss(φ_opt) <= loss(φ_opt + Δ)
@test loss(φ_opt) <= loss(φ_opt - Δ)
end
end
pΩ = aux_prior(lik, y)
@test pΩ isa ProductMeasure
@test kldivergence(first(marginals(qΩ)), first(marginals(pΩ))) isa Real
@test aux_kldivergence(lik, qΩ, pΩ) isa Real

# TODO test that aux_posterior parameters return the minimizing
φ = TupleVectors.unwrap(only(aux_posterior(lik, y, qft).inds)) # TupleVector
φ_opt = flatten_params(φ, nlatent(lik))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
Expand Down