diff --git a/Project.toml b/Project.toml
index 8292de97e2..da30569e78 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,7 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.15"
+NNlib = "0.8.17"
 NNlibCUDA = "0.2.6"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"
diff --git a/src/Flux.jl b/src/Flux.jl
index afa47f6fc0..23954c8e17 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -21,7 +21,7 @@ Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zyg
 
 export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion,
        RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
-       AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
+       AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, GlobalLPNormPool, MaxPool, MeanPool, LPNormPool,
        Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        Upsample, PixelShuffle,
        fmap, cpu, gpu, f32, f64, rand32, randn32, zeros32, ones32,
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 8a63b89bcc..40eb80866b 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -603,10 +603,6 @@ function (g::GlobalMaxPool)(x)
   return maxpool(x, pdims)
 end
 
-function Base.show(io::IO, g::GlobalMaxPool)
-  print(io, "GlobalMaxPool()")
-end
-
 """
     GlobalMeanPool()
 
@@ -637,8 +633,35 @@ function (g::GlobalMeanPool)(x)
   return meanpool(x, pdims)
 end
 
-function Base.show(io::IO, g::GlobalMeanPool)
-  print(io, "GlobalMeanPool()")
+"""
+    GlobalLPNormPool(p::Real)
+
+Global lp norm pooling layer.
+
+Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
+by performing lp norm pooling on the complete (w,h)-shaped feature maps.
+The input `x` is expected to satisfy `all(x .>= 0)` to avoid a `DomainError`.
+
+See also [`LPNormPool`](@ref).
+
+```jldoctest
+julia> xs = rand(Float32, 100, 100, 3, 50);
+
+julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0));
+
+julia> m(xs) |> size
+(1, 1, 7, 50)
+```
+"""
+struct GlobalLPNormPool
+  p::Real
+end
+
+function (g::GlobalLPNormPool)(x)
+  x_size = size(x)
+  k = x_size[1:end-2]
+  pdims = PoolDims(x, k)
+  return lpnormpool(x, pdims; p=g.p)
 end
 
 """
@@ -762,3 +785,72 @@ function Base.show(io::IO, m::MeanPool)
   m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
   print(io, ")")
 end
+
+"""
+    LPNormPool(window::NTuple, p::Real; pad=0, stride=window)
+
+Lp norm pooling layer, computing the `p`-norm of each window,
+known as `LPPool` in PyTorch.
+
+Expects as input an array with `ndims(x) == N+2`, i.e. channel and
+batch dimensions, after the `N` feature dimensions, where `N = length(window)`.
+The input is also expected to satisfy `all(x .>= 0)` to avoid a `DomainError`.
+
+By default the window size is also the stride in each dimension.
+The keyword `pad` accepts the same options as for the `Conv` layer,
+including `SamePad()`.
+
+See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalLPNormPool`](@ref),
+and PyTorch's [`LPPool`](https://pytorch.org/docs/stable/generated/torch.nn.LPPool2d.html).
+
+# Examples
+
+```jldoctest
+julia> xs = rand(Float32, 100, 100, 3, 50);
+
+julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2.0; pad=SamePad()))
+Chain(
+  Conv((5, 5), 3 => 7),                 # 532 parameters
+  LPNormPool((5, 5), 2.0, pad=2),
+)
+
+julia> m[1](xs) |> size
+(96, 96, 7, 50)
+
+julia> m(xs) |> size
+(20, 20, 7, 50)
+
+julia> layer = LPNormPool((5,), 2.0, pad=2, stride=(3,)) # one-dimensional window
+LPNormPool((5,), 2.0, pad=2, stride=3)
+
+julia> layer(rand(Float32, 100, 7, 50)) |> size
+(34, 7, 50)
+```
+"""
+struct LPNormPool{N,M}
+  k::NTuple{N,Int}
+  p::Real
+  pad::NTuple{M,Int}
+  stride::NTuple{N,Int}
+end
+
+function LPNormPool(k::NTuple{N,Integer}, p::Real; pad = 0, stride = k) where {N}
+  stride = expand(Val(N), stride)
+  pad = calc_padding(LPNormPool, pad, k, 1, stride)
+  return LPNormPool(k, p, pad, stride)
+end
+
+function (l::LPNormPool)(x)
+  iseven(l.p) || ChainRulesCore.@ignore_derivatives if any(<(0), x)
+    throw(DomainError("LPNormPool requires x to be non-negative"))
+  end
+  pdims = PoolDims(x, l.k; padding=l.pad, stride=l.stride)
+  return lpnormpool(x, pdims; p=l.p)
+end
+
+function Base.show(io::IO, l::LPNormPool)
+  print(io, "LPNormPool(", l.k, ", ", l.p)
+  all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
+  l.stride == l.k || print(io, ", stride=", _maybetuple_string(l.stride))
+  print(io, ")")
+end
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index c83b2c18d3..2844c55d9b 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -17,10 +17,14 @@ using Flux: gradient
   @test size(gmp(x)) == (1, 1, 3, 2)
   gmp = GlobalMeanPool()
   @test size(gmp(x)) == (1, 1, 3, 2)
+  glmp = GlobalLPNormPool(2.0)
+  @test size(glmp(x)) == (1, 1, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, PoolDims(x, 2))
   mp = MeanPool((2, 2))
   @test mp(x) == meanpool(x, PoolDims(x, 2))
+  lnp = LPNormPool((2,2), 2.0)
+  @test lnp(x) == lpnormpool(x, PoolDims(x, 2); p=2.0)
 end
 
 @testset "CNN" begin
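Note (not part of the patch): the docstrings above describe `lpnormpool` as taking the `p`-norm of each pooling window. A minimal sketch of that behaviour, assuming NNlib's `lpnormpool` computes `(sum(window .^ p))^(1/p)` for non-negative inputs, as the docstrings and tests suggest:

```julia
using Flux  # assumes the LPNormPool layer introduced by this patch

x = rand(Float32, 4, 4, 1, 1)   # one 4×4 feature map, all entries ≥ 0
p = 2.0

# The default stride equals the window, so a (2,2) window tiles the map into
# four disjoint windows; each output entry is the p-norm of one window.
y = LPNormPool((2, 2), p)(x)

manual = sum(x[1:2, 1:2, 1, 1] .^ p)^(1 / p)   # hand-computed first window
@assert isapprox(y[1, 1, 1, 1], manual; rtol = 1e-5)

# With p = 1 and non-negative inputs each window reduces to a plain sum,
# i.e. mean pooling scaled by the window size.
@assert LPNormPool((2, 2), 1.0)(x) ≈ MeanPool((2, 2))(x) .* 4
```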
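Under the same assumption, `GlobalLPNormPool(p)` should coincide with an `LPNormPool` whose window covers the whole spatial extent, i.e. the `p`-norm of each flattened feature map (sketch, not part of the patch):

```julia
using Flux                      # assumes GlobalLPNormPool / LPNormPool from this patch
using LinearAlgebra: norm

x = rand(Float32, 8, 8, 3, 2)   # (w, h, c, b) input, all entries ≥ 0

g = GlobalLPNormPool(2.0)
@assert size(g(x)) == (1, 1, 3, 2)

# Same result as LPNormPool with the window spanning the full (w, h) extent.
@assert g(x) ≈ LPNormPool((8, 8), 2.0)(x)

# Per channel and batch entry this is just the 2-norm of the flattened map.
@assert isapprox(g(x)[1, 1, 1, 1], norm(vec(x[:, :, 1, 1]), 2); rtol = 1e-5)
```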