Skip to content

Commit 0f93c1c

Browse files
committed
attempt a fix
1 parent e787af4 commit 0f93c1c

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

src/diffKernel.jl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ for higher order derivatives partial can be any iterable, i.e.
1818
k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
1919
```
2020
"""
21-
struct DiffPt{Dim}
21+
struct DiffPt
2222
pos # the actual position
2323
partial
2424
end
2525

26-
DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
26+
DiffPt(x; partial=()) = DiffPt(x, partial) # convenience constructor
2727

2828
"""
2929
partial(fun, idx)
@@ -34,10 +34,8 @@ Return ∂ᵢf where
3434
"""
3535
function partial(fun, idx)
3636
return x -> FD.derivative(0) do dx
37-
y = similar(x)
38-
y = copyto!(y, x)
39-
y[idx] += dx
40-
fun(y)
37+
dim = length(x)
38+
fun(x .+ dx * OneHotVector(idx, dim))
4139
end
4240
end
4341

@@ -58,23 +56,23 @@ end
5856
Take the partial derivative of a function with two dim-dimensional inputs,
5957
i.e. 2*dim dimensional input
6058
"""
61-
function partial(k, dim; partials_x=(), partials_y=())
62-
local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x)
63-
return (x, y) -> partial(t -> f(x, t), dim, partials_y)(y)
59+
function partial(k; partials_x=(), partials_y=())
60+
local f(x, y) = partial(t -> k(t, y), partials_x...)(x)
61+
return (x, y) -> partial(t -> f(x, t), partials_y...)(y)
6462
end
6563

6664
"""
67-
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
65+
_evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
6866
69-
implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. But since
67+
implements `(k::T)(x::DiffPt, y::DiffPt)` for all kernel types. But since
7068
generics are not allowed in the syntax above by the dispatch system, this
7169
redirection over `_evaluate` is necessary
7270
7371
unboxes the partial instructions from DiffPt and applies them to k,
7472
evaluates them at the positions of DiffPt
7573
"""
76-
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel}
77-
return partial(k, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
74+
function _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
75+
return partial(k, partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
7876
end
7977

8078
#=
@@ -101,7 +99,7 @@ for T in [
10199
NormalizedKernel,
102100
KernelTensorProduct
103101
] #subtypes(Kernel)
104-
(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
105-
(k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
106-
(k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
102+
(k::T)(x::DiffPt, y::DiffPt)= _evaluate(k, x, y)
103+
(k::T)(x::DiffPt, y) = _evaluate(k, x, DiffPt(y))
104+
(k::T)(x, y::DiffPt) = _evaluate(k, DiffPt(x), y)
107105
end

0 commit comments

Comments
 (0)