@@ -18,12 +18,12 @@ for higher order derivatives partial can be any iterable, i.e.
18
18
k(DiffPt(x, partial=(1,2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
19
19
```
20
20
"""
21
- struct DiffPt{Dim}
21
+ struct DiffPt
22
22
pos # the actual position
23
23
partial
24
24
end
25
25
26
- DiffPt (x; partial= ()) = DiffPt {length(x)} (x, partial) # convenience constructor
26
+ DiffPt (x; partial= ()) = DiffPt (x, partial) # convenience constructor
27
27
28
28
"""
29
29
partial(fun, idx)
@@ -34,10 +34,8 @@ Return ∂ᵢf where
34
34
"""
35
35
function partial (fun, idx)
36
36
return x -> FD. derivative (0 ) do dx
37
- y = similar (x)
38
- y = copyto! (y, x)
39
- y[idx] += dx
40
- fun (y)
37
+ dim = length (x)
38
+ fun (x .+ dx * OneHotVector (idx, dim))
41
39
end
42
40
end
43
41
58
56
Take the partial derivative of a function with two dim-dimensional inputs,
59
57
i.e. 2*dim dimensional input
60
58
"""
61
- function partial (k, dim ; partials_x= (), partials_y= ())
62
- local f (x, y) = partial (t -> k (t, y), dim, partials_x)(x)
63
- return (x, y) -> partial (t -> f (x, t), dim, partials_y)(y)
59
+ function partial (k; partials_x= (), partials_y= ())
60
+ local f (x, y) = partial (t -> k (t, y), partials_x... )(x)
61
+ return (x, y) -> partial (t -> f (x, t), partials_y... )(y)
64
62
end
65
63
66
64
"""
67
- _evaluate(k::T, x::DiffPt{Dim} , y::DiffPt{Dim} ) where {Dim, T<:Kernel}
65
+ _evaluate(k::T, x::DiffPt, y::DiffPt) where {T<:Kernel}
68
66
69
- implements `(k::T)(x::DiffPt{Dim} , y::DiffPt{Dim} )` for all kernel types. But since
67
+ implements `(k::T)(x::DiffPt, y::DiffPt)` for all kernel types. But since
70
68
generics are not allowed in the syntax above by the dispatch system, this
71
69
redirection over `_evaluate` is necessary
72
70
73
71
unboxes the partial instructions from DiffPt and applies them to k,
74
72
evaluates them at the positions of DiffPt
75
73
"""
76
- function _evaluate (k:: T , x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim, T<: Kernel }
77
- return partial (k, Dim; partials_x= x. partial, partials_y= y. partial)(x. pos, y. pos)
74
+ function _evaluate (k:: T , x:: DiffPt , y:: DiffPt ) where {T<: Kernel }
75
+ return partial (k, partials_x= x. partial, partials_y= y. partial)(x. pos, y. pos)
78
76
end
79
77
80
78
#=
@@ -101,7 +99,7 @@ for T in [
101
99
NormalizedKernel,
102
100
KernelTensorProduct
103
101
] # subtypes(Kernel)
104
- (k:: T )(x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, x, y)
105
- (k:: T )(x:: DiffPt{Dim} , y) where {Dim} = _evaluate (k, x, DiffPt (y))
106
- (k:: T )(x, y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, DiffPt (x), y)
102
+ (k:: T )(x:: DiffPt , y:: DiffPt ) = _evaluate (k, x, y)
103
+ (k:: T )(x:: DiffPt , y) = _evaluate (k, x, DiffPt (y))
104
+ (k:: T )(x, y:: DiffPt ) = _evaluate (k, DiffPt (x), y)
107
105
end
0 commit comments