Skip to content

use only() instead of first() #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ end

@functor ConstantKernel

kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x)
kappa(κ::ConstantKernel, x::Real) = only(κ.c) * one(x)

metric(::ConstantKernel) = Delta()

Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first(κ.c), ")")
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", only(κ.c), ")")
4 changes: 2 additions & 2 deletions src/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ end

@functor GammaExponentialKernel

kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ))
kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^only(κ.γ))

metric(k::GammaExponentialKernel) = k.metric

iskroncompatible(::GammaExponentialKernel) = true

function Base.show(io::IO, κ::GammaExponentialKernel)
return print(
io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ", metric = ", κ.metric, ")"
io, "Gamma Exponential Kernel (γ = ", only(κ.γ), ", metric = ", κ.metric, ")"
)
end
6 changes: 3 additions & 3 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
modX = sum(abs2, x)
modY = sum(abs2, y)
modXY = sqeuclidean(x, y)
h = first(κ.h)
h = only(κ.h)
return (modX^h + modY^h - modXY^h) / 2
end

function (κ::FBMKernel)(x::Real, y::Real)
return (abs2(x)^first(κ.h) + abs2(y)^first(κ.h) - abs2(x - y)^first(κ.h)) / 2
return (abs2(x)^only(κ.h) + abs2(y)^only(κ.h) - abs2(x - y)^only(κ.h)) / 2
end

function Base.show(io::IO, κ::FBMKernel)
return print(io, "Fractional Brownian Motion Kernel (h = ", first(κ.h), ")")
return print(io, "Fractional Brownian Motion Kernel (h = ", only(κ.h), ")")
end

_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h) / 2
Expand Down
4 changes: 2 additions & 2 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
@functor MaternKernel

@inline function kappa(κ::MaternKernel, d::Real)
result = _matern(first(κ.ν), d)
result = _matern(only(κ.ν), d)
return ifelse(iszero(d), one(result), result)
end

Expand All @@ -46,7 +46,7 @@ end
metric(k::MaternKernel) = k.metric

function Base.show(io::IO, κ::MaternKernel)
return print(io, "Matern Kernel (ν = ", first(κ.ν), ", metric = ", κ.metric, ")")
return print(io, "Matern Kernel (ν = ", only(κ.ν), ", metric = ", κ.metric, ")")
end

## Matern12Kernel = ExponentialKernel aliased in exponential.jl
Expand Down
10 changes: 5 additions & 5 deletions src/basekernels/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c)

@functor LinearKernel

kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + only(κ.c)

metric(::LinearKernel) = DotProduct()

Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first(κ.c), ")")
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only(κ.c), ")")

"""
PolynomialKernel(; degree::Int=2, c::Real=0.0)
Expand All @@ -53,7 +53,7 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel

function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
@check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
@check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0")
@check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0")
return new{Tc}(degree, c)
end
end
Expand All @@ -68,10 +68,10 @@ function Functors.functor(::Type{<:PolynomialKernel}, x)
return (c=x.c,), reconstruct_polynomialkernel
end

kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^κ.degree
kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + only(κ.c))^κ.degree

metric(::PolynomialKernel) = DotProduct()

function Base.show(io::IO, κ::PolynomialKernel)
return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")")
return print(io, "Polynomial Kernel (c = ", only(κ.c), ", degree = ", κ.degree, ")")
end
16 changes: 8 additions & 8 deletions src/basekernels/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ end
@functor RationalKernel

function kappa(κ::RationalKernel, d::Real)
return (one(d) + d / first(κ.α))^(-first(κ.α))
return (one(d) + d / only(κ.α))^(-only(κ.α))
end

metric(k::RationalKernel) = k.metric

function Base.show(io::IO, κ::RationalKernel)
return print(io, "Rational Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")")
return print(io, "Rational Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")")
end

"""
Expand Down Expand Up @@ -72,18 +72,18 @@ end
@functor RationalQuadraticKernel

function kappa(κ::RationalQuadraticKernel, d::Real)
return (one(d) + d^2 / (2 * first(κ.α)))^(-first(κ.α))
return (one(d) + d^2 / (2 * only(κ.α)))^(-only(κ.α))
end
function kappa(κ::RationalQuadraticKernel{<:Real,<:Euclidean}, d²::Real)
return (one(d²) + d² / (2 * first(κ.α)))^(-first(κ.α))
return (one(d²) + d² / (2 * only(κ.α)))^(-only(κ.α))
end

metric(k::RationalQuadraticKernel) = k.metric
metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean()

function Base.show(io::IO, κ::RationalQuadraticKernel)
return print(
io, "Rational Quadratic Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")"
io, "Rational Quadratic Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")"
)
end

Expand Down Expand Up @@ -122,7 +122,7 @@ end
@functor GammaRationalKernel

function kappa(κ::GammaRationalKernel, d::Real)
return (one(d) + d^first(κ.γ) / first(κ.α))^(-first(κ.α))
return (one(d) + d^only(κ.γ) / only(κ.α))^(-only(κ.α))
end

metric(k::GammaRationalKernel) = k.metric
Expand All @@ -131,9 +131,9 @@ function Base.show(io::IO, κ::GammaRationalKernel)
return print(
io,
"Gamma Rational Kernel (α = ",
first(κ.α),
only(κ.α),
", γ = ",
first(κ.γ),
only(κ.γ),
", metric = ",
κ.metric,
")",
Expand Down
4 changes: 3 additions & 1 deletion src/distances/sinus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ struct Sinus{T} <: Distances.UnionSemiMetric
r::Vector{T}
end

Sinus(r::Real) = Sinus([r])

Distances.parameters(d::Sinus) = d.r
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / only(dist.r))

Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)

Expand Down
4 changes: 2 additions & 2 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

@functor ScaledKernel

(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y)
(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y)

function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
return κ.σ² .* kernelmatrix(κ.kernel, x, y)
Expand Down Expand Up @@ -75,5 +75,5 @@ Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

function printshifted(io::IO, κ::ScaledKernel, shift::Int)
printshifted(io, κ.kernel, shift)
return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(first(κ.σ²))")
return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(only(κ.σ²))")
end
4 changes: 2 additions & 2 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(
end

function _scale(t::ScaleTransform, metric::Euclidean, x, y)
return first(t.s) * evaluate(metric, x, y)
return only(t.s) * evaluate(metric, x, y)
end
function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
return first(t.s)^2 * evaluate(metric, x, y)
return only(t.s)^2 * evaluate(metric, x, y)
end
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))

Expand Down
2 changes: 1 addition & 1 deletion src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end

dim(t::ARDTransform) = length(t.v)

(t::ARDTransform)(x::Real) = first(t.v) * x
(t::ARDTransform)(x::Real) = only(t.v) * x
(t::ARDTransform)(x) = t.v .* x

_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
Expand Down
8 changes: 4 additions & 4 deletions src/transform/periodic_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ PeriodicTransform(f::Real) = PeriodicTransform([f])

dim(t::PeriodicTransform) = 2

(t::PeriodicTransform)(x::Real) = [sinpi(2 * first(t.f) * x), cospi(2 * first(t.f) * x)]
(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]

function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
return RowVecs(hcat(sinpi.((2 * first(t.f)) .* x), cospi.((2 * first(t.f)) .* x)))
return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
end

function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform)
return isequal(first(t1.f), first(t2.f))
return isequal(only(t1.f), only(t2.f))
end

function Base.show(io::IO, t::PeriodicTransform)
return print(io, "Periodic Transform with frequency $(first(t.f))")
return print(io, "Periodic Transform with frequency $(only(t.f))")
end
12 changes: 6 additions & 6 deletions src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ end

set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]

(t::ScaleTransform)(x) = first(t.s) * x
(t::ScaleTransform)(x) = only(t.s) * x

_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)

Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(first(t.s), first(t2.s))
Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))

Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", first(t.s), ")")
Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")")
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Expand Down
2 changes: 1 addition & 1 deletion test/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@

# Standardised tests.
TestUtils.test_interface(k, Float64)
test_ADs(c -> ConstantKernel(; c=first(c)), [c])
test_ADs(c -> ConstantKernel(; c=only(c)), [c])
end
end
2 changes: 1 addition & 1 deletion test/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
@test metric(k2) isa WeightedEuclidean
@test k2(v1, v2) ≈ k(v1, v2)

test_ADs(γ -> GammaExponentialKernel(; gamma=first(γ)), [1 + 0.5 * rand()])
test_ADs(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()])
test_params(k, ([γ],))
TestUtils.test_interface(GammaExponentialKernel(; γ=1.36))

Expand Down
3 changes: 2 additions & 1 deletion test/distances/sinus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
d = KernelFunctions.Sinus(p)
@test Distances.parameters(d) == p
@test evaluate(d, A, B) == sum(abs2.(sinpi.(A - B) ./ p))
@test d(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
d1 = KernelFunctions.Sinus(first(p))
@test d1(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Zygote: Zygote
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using FiniteDifferences: FiniteDifferences
using Compat: only

using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils

Expand Down
4 changes: 2 additions & 2 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const FDM = FiniteDifferences.central_fdm(5, 1)
gradient(f, s::Symbol, args) = gradient(f, Val(s), args)

function gradient(f, ::Val{:Zygote}, args)
g = first(Zygote.gradient(f, args))
g = only(Zygote.gradient(f, args))
if isnothing(g)
if args isa AbstractArray{<:Real}
return zeros(size(args)) # To respect the same output as other ADs
Expand All @@ -66,7 +66,7 @@ function gradient(f, ::Val{:ReverseDiff}, args)
end

function gradient(f, ::Val{:FiniteDiff}, args)
return first(FiniteDifferences.grad(FDM, f, args))
return only(FiniteDifferences.grad(FDM, f, args))
end

function compare_gradient(f, ::Val{:FiniteDiff}, args)
Expand Down