diff --git a/Project.toml b/Project.toml index c763a8824..278444f08 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 0cb80e5ce..89f69e6d1 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -23,7 +23,7 @@ kernels = Dict( inputtypes = Dict("ColVecs" => (Xc, Yc), "RowVecs" => (Xr, Yr), "Vecs" => (Xv, Yv)) functions = Dict( - "kernelmatrixX" => (kernel, X, Y) -> kernelmatrix(kernel, X), + "kernelmatrixX" => (kernel, X, Y) -> invoke(kernelmatrix, Tuple{kernel, Any}, kernel, X), "kernelmatrixXY" => (kernel, X, Y) -> kernelmatrix(kernel, X, Y), "kernelmatrix_diagX" => (kernel, X, Y) -> kernelmatrix_diag(kernel, X), "kernelmatrix_diagXY" => (kernel, X, Y) -> kernelmatrix_diag(kernel, X, Y), @@ -41,6 +41,14 @@ end # Uncomment the following to run benchmark locally -# tune!(SUITE) +tune!(SUITE) -# results = run(SUITE, verbose=true) +results = run(SUITE, verbose=true) + +Xc = ColVecs(rand(2, 2000)) +k = SqExponentialKernel() + +@which kernelmatrix(k, Xc) +@btime kernelmatrix($k, $Xc); +@btime invoke(kernelmatrix, Tuple{KernelFunctions.SimpleKernel, $typeof(Xc)}, $k, $Xc); +# @btime invoke(kernelmatrix, Tuple{Kernel, $typeof(Xc)}, $k, $Xc); \ No newline at end of file diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..a291fbc59 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -43,6 +43,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou export IndependentMOKernel, LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel +export DiffPt + # Reexports export tensor, ⊗, compose @@ -125,6 +127,7 @@ include("chainrules.jl") include("zygoterules.jl") include("TestUtils.jl") +include("diffKernel.jl") function __init__() @require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin diff --git a/src/diffKernel.jl b/src/diffKernel.jl new file mode 100644 index 000000000..8ed0a8b88 --- /dev/null +++ b/src/diffKernel.jl @@ -0,0 +1,101 @@ +import ForwardDiff as FD +import LinearAlgebra as LA + +""" + DiffPt(x; partial=()) + +For a covariance kernel k of GP Z, i.e. +```julia + k(x,y) # = Cov(Z(x), Z(y)), +``` +a DiffPt allows the differentiation of Z, i.e. +```julia + k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y)) +``` +for higher order derivatives partial can be any iterable, i.e. +```julia + k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y)) +``` +""" + +IndexType = Union{Int,Base.AbstractCartesianIndex} + +struct DiffPt{Order,KeyT<:IndexType,T} + pos::T # the actual position + partials::NTuple{Order,KeyT} +end + +DiffPt(x::T) where {T<:AbstractArray} = DiffPt{0,keytype(T),T}(x, ()::NTuple{0,keytype(T)}) +DiffPt(x::T) where {T<:Number} = DiffPt{0,Int,T}(x, ()::NTuple{0,Int}) +DiffPt(x::T, partial::IndexType) where {T} = DiffPt{1,IndexType,T}(x, (partial,)) +function DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT} + return DiffPt{Order,KeyT,T}(x, partials) +end + +""" + tangentCurve(x₀, i::IndexType) +returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index i +""" +function tangentCurve(x0::AbstractArray, idx::IndexType) + return t -> begin + x = similar(x0, promote_type(eltype(x0), typeof(t))) + copyto!(x, x0) + x[idx] += t + return x + end +end +function tangentCurve(x0::Number, ::IndexType) + return t -> x0 + t +end + +partial(func) = func +function partial(func, idx::IndexType) + return x -> FD.derivative(func ∘ tangentCurve(x, idx), 0) +end +function partial(func, partials::IndexType...) + idx, state = iterate(partials) + return partial( + x -> FD.derivative(func ∘ tangentCurve(x, idx), 0), Base.rest(partials, state)... + ) +end + +""" +Take the partial derivative of a function with two dim-dimensional inputs, +i.e. 2*dim dimensional input +""" +function partial( + k, partials_x::Tuple{Vararg{T}}, partials_y::Tuple{Vararg{T}} +) where {T<:IndexType} + local f(x, y) = partial(t -> k(t, y), partials_x...)(x) + return (x, y) -> partial(t -> f(x, t), partials_y...)(y) +end + +""" + _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel} + +implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since +generics are not allowed in the syntax above by the dispatch system, this +redirection over `_evaluate` is necessary + +unboxes the partial instructions from DiffPt and applies them to k, +evaluates them at the positions of DiffPt +""" +function _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel} + return partial(k, x.partials, y.partials)(x.pos, y.pos) +end + +#= +This is a hack to work around the fact that the `where {T<:Kernel}` clause is +not allowed for the `(::T)(x,y)` syntax. If we were to only implement +```julia + (::Kernel)(::DiffPt,::DiffPt) +``` +then julia would not know whether to use +`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)` +``` +=# +for T in [SimpleKernel, Kernel] #subtypes(Kernel) + (k::T)(x::DiffPt, y::DiffPt) = _evaluate(k, x, y) + (k::T)(x::DiffPt, y) = _evaluate(k, x, DiffPt(y)) + (k::T)(x, y::DiffPt) = _evaluate(k, DiffPt(x), y) +end diff --git a/src/mokernels/differentiable.jl b/src/mokernels/differentiable.jl new file mode 100644 index 000000000..1825efa60 --- /dev/null +++ b/src/mokernels/differentiable.jl @@ -0,0 +1,31 @@ + +struct Partial{Order} + indices::CartesianIndex{Order} +end + +function Partial(indices::Integer...) + return Partial{length(indices)}(CartesianIndex(indices)) +end + +compact_string_representation(::Partial{0}) = print(io, "id") +function compact_string_representation(p::Partial) + tuple = Tuple(p.indices) + lower_numbers = @. (n -> '₀' + n)(reverse(digits(tuple))) + return join(["∂$(join(x))" for x in lower_numbers]) +end +function Base.show(io::IO, ::MIME"text/plain", p::Partial) + if get(io, :compact, false) + print(io, "Partial($(Tuple(p.indices)))") + else + print(io, compact_string_representation(p)) + end +end + +function Base.show(io::IO, ::MIME"text/html", p::Partial) + tuple = Tuple(p.indices) + if get(io, :compact, false) + print(io, join(map(n -> "∂$(n)", tuple), "")) + else + print(io, compact_string_representation(p)) + end +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index e16f39a6c..5b21d7b70 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/test/diffKernel.jl b/test/diffKernel.jl new file mode 100644 index 000000000..3a5e3c7e1 --- /dev/null +++ b/test/diffKernel.jl @@ -0,0 +1,25 @@ +@testset "diffKernel" begin + @testset "smoke test" begin + k = MaternKernel() + k(1, 1) + k(1, DiffPt(1, (1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1 + k(DiffPt([1], 1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2] + k(DiffPt([1, 2], 1), DiffPt([1, 2], 2))# Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2] + end + + @testset "Sanity Checks with $k" for k in [SEKernel()] + for x in [0, 1, -1, 42] + # for stationary kernels Cov(∂Z(x) , Z(x)) = 0 + @test k(DiffPt(x, 1), x) ≈ 0 + + # the slope should be positively correlated with a point further down + @test k( + DiffPt(x, 1), # slope + x + 1e-1, # point further down + ) > 0 + + # correlation with self should be positive + @test k(DiffPt(x, 1), DiffPt(x, 1)) > 0 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index caf43cb91..accdea1f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -176,6 +176,7 @@ include("test_utils.jl") include("generic.jl") include("chainrules.jl") include("zygoterules.jl") + include("diffKernel.jl") @testset "doctests" begin DocMeta.setdocmeta!(