Add LazyKernelMatrix and lazykernelmatrix #515

Draft: wants to merge 4 commits into master
10 changes: 9 additions & 1 deletion docs/src/api.md
@@ -6,14 +6,22 @@ CurrentModule = KernelFunctions

## Functions

The KernelFunctions API comprises the following four functions.
The KernelFunctions API comprises the following functions.

The first set eagerly constructs all or part of a kernel matrix:
```@docs
kernelmatrix
kernelmatrix!
kernelmatrix_diag
kernelmatrix_diag!
```

It is also possible to lazily construct the same matrix, which is recommended when the kernel matrix might be too large to store in memory:
```@docs
lazykernelmatrix
LazyKernelMatrix
```

## Input Types

The above API operates on collections of inputs.
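For orientation, a minimal sketch of the eager/lazy distinction this hunk documents (assuming the API as added in this PR):
```julia
using KernelFunctions

k = SqExponentialKernel()
x = randn(10)

K_eager = kernelmatrix(k, x)     # 10×10 Matrix, all entries computed now
K_lazy = lazykernelmatrix(k, x)  # entries computed on demand via getindex
K_lazy[2, 5] ≈ K_eager[2, 5]     # true
```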
2 changes: 1 addition & 1 deletion docs/src/create_kernel.md
@@ -38,7 +38,7 @@ Finally there are additional functions you can define to bring in more features:
- `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes in dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods.
- `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs will only be checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
- `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types.
- `kernelmatrix(k::MyKernel, ...)`: you can redefine the diverse `kernelmatrix` functions to eventually optimize the computations.
- `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` and `lazykernelmatrix` functions to optimize the computations (see the sketch after this list).
- `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel.
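As a sketch of the `kernelmatrix` point above: one might specialize it for a custom kernel to avoid the generic pairwise fallback. `MyKernel` and the broadcasting shortcut are illustrative assumptions, not code from this PR:
```julia
using KernelFunctions

# Hypothetical custom kernel; the name and definition are illustrative only.
struct MyKernel <: KernelFunctions.Kernel end
(k::MyKernel)(x::Real, y::Real) = exp(-abs2(x - y))

# Specialize kernelmatrix to use a single broadcast over all input pairs.
function KernelFunctions.kernelmatrix(k::MyKernel, x::AbstractVector{<:Real})
    return exp.(-abs2.(x .- x'))
end
```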

KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters
9 changes: 8 additions & 1 deletion docs/src/userguide.md
@@ -61,7 +61,7 @@ k(x1, x2)

## Creating a Kernel Matrix

Kernel matrices can be created via the `kernelmatrix` function or `kernelmatrix_diag` for only the diagonal.
Kernel matrices can be eagerly created via the `kernelmatrix` function or `kernelmatrix_diag` for only the diagonal.
For example, for a collection of 10 `Real`-valued inputs:
```julia
k = SqExponentialKernel()
@@ -90,6 +90,13 @@ kernelmatrix(k, X; obsdim=2) # same as ColVecs(X)
```
This is similar to the convention used in [Distances.jl](https://github.com/JuliaStats/Distances.jl).

When the data is large, it may not be possible to store the kernel matrix in memory.
In that case, it is recommended to use `lazykernelmatrix`:
```julia
lazykernelmatrix(k, RowVecs(X))
lazykernelmatrix(k, ColVecs(X))
```

### So what type should I use to represent a collection of inputs?
The central assumption made by KernelFunctions.jl is that all collections of `N` inputs are represented by `AbstractVector`s of length `N`.
Abstraction is then used to ensure that efficiency is retained, `ColVecs` and `RowVecs`
2 changes: 2 additions & 0 deletions src/KernelFunctions.jl
@@ -1,6 +1,7 @@
module KernelFunctions

export kernelmatrix, kernelmatrix!, kernelmatrix_diag, kernelmatrix_diag!
export LazyKernelMatrix, lazykernelmatrix
Member: Do we have to export both? Is `lazykernelmatrix` sufficient, maybe?

Author: Probably.

export duplicate, set! # Helpers

export Kernel, MOKernel
@@ -106,6 +107,7 @@ include("kernels/gibbskernel.jl")
include("kernels/scaledkernel.jl")
include("kernels/normalizedkernel.jl")
include("matrix/kernelmatrix.jl")
include("matrix/lazykernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
include("kernels/kerneltensorproduct.jl")
4 changes: 4 additions & 0 deletions src/matrix/kernelmatrix.jl
@@ -30,12 +30,16 @@ Compute the kernel `κ` for each pair of inputs in `x`.
Returns a matrix of size `(length(x), length(x))` satisfying
`kernelmatrix(κ, x)[p, q] == κ(x[p], x[q])`.

If `x` is large, consider using [`lazykernelmatrix`](@ref) instead.

kernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector)

Compute the kernel `κ` for each pair of inputs in `x` and `y`.
Returns a matrix of size `(length(x), length(y))` satisfying
`kernelmatrix(κ, x, y)[p, q] == κ(x[p], y[q])`.

If `x` and `y` are large, consider using [`lazykernelmatrix`](@ref) instead.

kernelmatrix(κ::Kernel, X::AbstractMatrix; obsdim)
kernelmatrix(κ::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim)

109 changes: 109 additions & 0 deletions src/matrix/lazykernelmatrix.jl
@@ -0,0 +1,109 @@
"""
lazykernelmatrix(κ::Kernel, x::AbstractVector) -> AbstractMatrix

Construct a lazy representation of the kernel `κ` for each pair of inputs in `x`.

The result is a matrix with the same entries as [`kernelmatrix(κ, x)`](@ref) but where the
entries are not computed until they are needed.
"""
lazykernelmatrix(κ::Kernel, x) = lazykernelmatrix(κ, x, x)
Member: It would be good to optimize this for the symmetric case, IMO, similar to `kernelmatrix` (which IIRC often does not use such a fallback but more optimized methods).

Reply: The optimization for the symmetric case is only calculating half the matrix, let's say the top half, and then redirecting all queries from the bottom half to the top half. (Actually, Distances simply copies the top half into the bottom half.)

Since this is lazy, there is probably no point in this optimization, because you do not do the calculation from the start. And when you call `getindex`, it does not matter whether you calculate the element in the top or bottom half.
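For concreteness, a minimal sketch of the redirection idea under discussion; `SymmetricLazy` is a hypothetical toy type, not part of this PR:
```julia
# Hypothetical toy type: a lazy symmetric matrix whose getindex redirects
# lower-triangle queries to the upper triangle. As noted above, for a fully
# lazy matrix this saves nothing per query; it would only pay off if entries
# were cached, or if materialization walked a single triangle.
struct SymmetricLazy{T,F} <: AbstractMatrix{T}
    f::F      # entry generator, assumed symmetric: f(i, j) == f(j, i)
    n::Int
end
SymmetricLazy(f, n::Int) = SymmetricLazy{typeof(f(1, 1)),typeof(f)}(f, n)

Base.size(A::SymmetricLazy) = (A.n, A.n)
Base.getindex(A::SymmetricLazy, i::Int, j::Int) = i > j ? A[j, i] : A.f(i, j)
```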

Member: At least with other lazy iterators in Base it's a common pattern to collect results at some point (e.g., after filtering, mapping, etc.). In this case it seems beneficial to know that the lazy matrix is symmetric.

Author: Good point, I'll see if there are ops we can optimize without too much extra code complexity.

@FelixBenning (Jun 2, 2023): Is there an interface for that? I mean, you could use `LinearAlgebra.Symmetric`, since that just wraps the original matrix AFAIK, but it also simply redirects queries, so `collect` would still cause two calculations, since you do a calculation per query.

I mean, you could just specialize `collect`, I guess.
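A hedged sketch of that last suggestion, specializing materialization (here via `Base.Matrix`) so that only the upper triangle is evaluated and then mirrored; it reuses the hypothetical `SymmetricLazy` type from the previous sketch:
```julia
using LinearAlgebra

# Hypothetical: materialize by evaluating only the upper triangle, then let
# Symmetric mirror it, so each off-diagonal pair costs one evaluation.
function Base.Matrix(A::SymmetricLazy{T}) where {T}
    n = size(A, 1)
    M = Matrix{T}(undef, n, n)
    for j in 1:n, i in 1:j
        M[i, j] = A.f(i, j)
    end
    return Matrix(Symmetric(M, :U))  # copies the upper triangle into the lower
end
```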


"""
lazykernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) -> AbstractMatrix

Construct a lazy representation of the kernel `κ` for each pair of inputs in `x` and `y`.

The result is a matrix with the same entries as [`kernelmatrix(κ, x, y)`](@ref) but where
the entries are not computed until they are needed.
"""
lazykernelmatrix(κ::Kernel, x, y) = LazyKernelMatrix(κ, x, y)

"""
LazyKernelMatrix(κ::Kernel, x[, y])
LazyKernelMatrix{T<:Real}(κ::Kernel, x, y)

Construct a lazy representation of the kernel `κ` for each pair of inputs in `x` and `y`.

Instead of constructing this directly, it is better to call
[`lazykernelmatrix(κ, x[, y])`](@ref lazykernelmatrix).
"""
struct LazyKernelMatrix{T<:Real,Tk<:Kernel,Tx<:AbstractVector,Ty<:AbstractVector} <:
AbstractMatrix{T}
kernel::Tk
x::Tx
y::Ty
function LazyKernelMatrix{T}(κ::Tk, x::Tx, y::Ty) where {T<:Real,Tk<:Kernel,Tx,Ty}
Base.require_one_based_indexing(x)
Base.require_one_based_indexing(y)
return new{T,Tk,Tx,Ty}(κ, x, y)
end
function LazyKernelMatrix{T}(κ::Tk, x::Tx) where {T<:Real,Tk<:Kernel,Tx}
Base.require_one_based_indexing(x)
return new{T,Tk,Tx,Tx}(κ, x, x)
end
end
function LazyKernelMatrix(κ::Kernel, x::AbstractVector, y::AbstractVector)
# evaluate once to get eltype
T = typeof(κ(first(x), first(y)))
return LazyKernelMatrix{T}(κ, x, y)
end
LazyKernelMatrix(κ::Kernel, x::AbstractVector) = LazyKernelMatrix(κ, x, x)

Base.Matrix(K::LazyKernelMatrix) = kernelmatrix(K.kernel, K.x, K.y)
function Base.AbstractMatrix{T}(K::LazyKernelMatrix) where {T}
return LazyKernelMatrix{T}(K.kernel, K.x, K.y)
end

Base.size(K::LazyKernelMatrix) = (length(K.x), length(K.y))

Base.axes(K::LazyKernelMatrix) = (axes(K.x, 1), axes(K.y, 1))

function Base.getindex(K::LazyKernelMatrix{T}, i::Int, j::Int) where {T}
return T(K.kernel(K.x[i], K.y[j]))
end
for f in (:getindex, :view)
@eval begin
function Base.$f(
K::LazyKernelMatrix{T},
I::Union{Colon,AbstractVector},
J::Union{Colon,AbstractVector},
) where {T}
return LazyKernelMatrix{T}(K.kernel, $f(K.x, I), $f(K.y, J))
end
end
end
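For illustration, a usage sketch of the methods above (assuming the PR as written): indexing with ranges returns another lazy matrix instead of materializing:
```julia
using KernelFunctions

k = SqExponentialKernel()
x = range(0.0, 1.0; length=1_000)
K = lazykernelmatrix(k, x)   # no entries computed yet
K[3, 7]                      # computes k(x[3], x[7]) on demand
Ksub = K[1:10, 1:10]         # still a LazyKernelMatrix, over x[1:10]
Matrix(Ksub)                 # materializes eagerly via kernelmatrix
```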

Base.zero(K::LazyKernelMatrix{T}) where {T} = LazyKernelMatrix{T}(ZeroKernel(), K.x, K.y)
Base.one(K::LazyKernelMatrix{T}) where {T} = LazyKernelMatrix{T}(WhiteKernel(), K.x, K.y)

function Base.:*(c::S, K::LazyKernelMatrix{T}) where {T,S<:Real}
R = typeof(oneunit(S) * oneunit(T))
return LazyKernelMatrix{R}(c * K.kernel, K.x, K.y)
end
Base.:*(K::LazyKernelMatrix, c::Real) = c * K
Base.:/(K::LazyKernelMatrix, c::Real) = K * inv(c)
Base.:\(c::Real, K::LazyKernelMatrix) = inv(c) * K

function Base.:+(K::LazyKernelMatrix{T}, C::UniformScaling{S}) where {T,S<:Real}
if isequal(K.x, K.y)
R = typeof(zero(T) + zero(S))
return LazyKernelMatrix{R}(K.kernel + C.λ * WhiteKernel(), K.x, K.y)
else
return Matrix(K) + C
end
end
function Base.:+(C::UniformScaling{S}, K::LazyKernelMatrix{T}) where {T,S<:Real}
if isequal(K.x, K.y)
R = typeof(zero(T) + zero(S))
return LazyKernelMatrix{R}(C.λ * WhiteKernel() + K.kernel, K.x, K.y)
else
return C + Matrix(K)
end
end
function Base.:+(K1::LazyKernelMatrix, K2::LazyKernelMatrix)
if isequal(K1.x, K2.x) && isequal(K1.y, K2.y)
return LazyKernelMatrix(K1.kernel + K2.kernel, K1.x, K1.y)
else
return Matrix(K1) + Matrix(K2)
end
end
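Finally, a usage sketch of the scalar and `UniformScaling` algebra defined above (assuming the PR as written; the jitter value `0.1` is arbitrary):
```julia
using KernelFunctions, LinearAlgebra

k = SqExponentialKernel()
x = randn(100)
K = lazykernelmatrix(k, x)

K2 = 2.0 * K      # lazy: wraps the scaled kernel 2.0 * k
Kjit = K + 0.1I   # lazy: folds into k + 0.1 * WhiteKernel(), since K.x == K.y
Ksum = K + K2     # lazy: wraps the kernel sum k + 2.0 * k
```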