-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathfunctiontransform.jl
44 lines (32 loc) · 1.01 KB
/
functiontransform.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
FunctionTransform(f)
Transformation that applies function `f` to the input.
Make sure that `f` can act on an input. For instance, if the inputs are vectors, use
`f(x) = sin.(x)` instead of `f = sin`.
# Examples
```jldoctest
julia> f(x) = sum(x); t = FunctionTransform(f); X = randn(100, 10);
julia> map(t, ColVecs(X)) == ColVecs(sum(X; dims=1))
true
```
"""
struct FunctionTransform{F} <: Transform
f::F
end
# Register `FunctionTransform` with the `@functor` macro.
# NOTE(review): presumably this is Functors.jl's `@functor`, which exposes the wrapped `f`
# to structure-traversal utilities — confirm against the enclosing module's imports.
@functor FunctionTransform
# Calling the transform on a single input `x` forwards `x` to the wrapped function and
# returns whatever that function returns.
function (t::FunctionTransform)(x)
    return t.f(x)
end
# Element-wise application for a vector of real scalars: each entry of `x` is passed to
# the wrapped function individually.
function Base.map(t::FunctionTransform, x::AbstractVector{<:Real})
    return map(t.f, x)
end
# Column-wise application: `t.f` receives each column of `x.X` as a view (no copy), and
# the outputs are joined side by side with `hcat` to form the data of the returned
# `ColVecs`.
function Base.map(t::FunctionTransform, x::ColVecs)
    data = x.X
    outputs = [t.f(view(data, :, j)) for j in axes(data, 2)]
    return ColVecs(reduce(hcat, outputs))
end
# Apply `t.f` to every row of the matrix wrapped by `x` and collect the results as a new
# `RowVecs`.
#
# Each row is passed to `t.f` as a view (no copy). The per-row outputs are concatenated
# with `hcat` — one column per input row — and the result is then flipped so that each
# output occupies one row again.
function Base.map(t::FunctionTransform, x::RowVecs)
    vals = map(axes(x.X, 1)) do i
        t.f(view(x.X, i, :))
    end
    # Use `transpose`, not `'`: the adjoint would also complex-conjugate the outputs
    # whenever `t.f` returns complex values. For real outputs the two are identical.
    return RowVecs(transpose(reduce(hcat, vals)))
end
# Build a new `FunctionTransform` around `f`. The existing transform `t` is used only for
# dispatch; none of its contents are read.
function duplicate(t::FunctionTransform, f)
    return FunctionTransform(f)
end
# Compact display: a fixed label followed by the printed form of the wrapped function.
function Base.show(io::IO, t::FunctionTransform)
    return print(io, "Function Transform: ", t.f)
end