Add LPNormPool #2166
base: master
Conversation
Codecov Report — Base: 87.05% // Head: 84.49% // Decreases project coverage by -2.56%.
@@ Coverage Diff @@
## master #2166 +/- ##
==========================================
- Coverage 87.05% 84.49% -2.56%
==========================================
Files 19 19
Lines 1491 1503 +12
==========================================
- Hits 1298 1270 -28
- Misses 193 233 +40
☔ View full report at Codecov.
Looks really good, thanks! A couple small comments.
src/layers/conv.jl (Outdated)
LPNormPool(window::NTuple, p::Number; pad=0, stride=window)

Lp norm pooling layer, calculating p-norm distance for each window,
also known as LPPool in pytorch.
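A hypothetical usage sketch based on this signature (assuming the layer follows the same window/stride conventions as Flux's `MaxPool`; the expected output size is an assumption, not taken from this diff):

using Flux  # assumes a Flux build that includes this PR's LPNormPool

pool = LPNormPool((2, 2), 2.0)   # 2×2 windows, p = 2
x = rand(Float32, 8, 8, 3, 1)    # WHCN: width × height × channels × batch
size(pool(x))                    # expected (4, 4, 3, 1), since stride = window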
One vote to call this `NormPool`. One reason besides brevity is that it'll be hard to remember whether the p is uppercase -- you write "Lp norm" just below, which I think is standard (lower-case, ideally subscript) and suggests `LpNormPool`.
I think it's good to mention pytorch's name, to be searchable. But IMO this should be somewhere down next to "See also", not in the first sentence.
I still have no idea about the struct name. Hope for more discussion :)
About the PyTorch mention: I added it here so that users can find more quickly that it is identical to `LPPool` in PyTorch. I think it's nice to append `LPPool` in "See also" instead. I'll add it in a new commit.
Looking at the PyTorch implementation, there does appear to be a meaningful difference from taking the norm of every window. Lp-pooling lacks the absolute value used in the p-norm, so negative inputs remain negative at odd powers. Unless we want to deviate from PyTorch, I think that lends credence to not calling this "NormPool".
>>> x = torch.tensor([-1, -2, -3, 1, 2, 3]).reshape(1, 1, -1)
>>> torch.nn.functional.lp_pool1d(x, 1, 1)
tensor([[[-1., -2., -3., 1., 2., 3.]]])
Oh that's very different. What does it do for power 3, and for 1.5?
>>> torch.nn.functional.lp_pool1d(x, 3, 1)
tensor([[[nan, nan, nan, 1., 2., 3.]]])
>>> torch.nn.functional.lp_pool1d(x, 1.5, 1)
tensor([[[nan, nan, nan, 1., 2., 3.]]])
Compare:
>>> torch.linalg.norm(torch.tensor([-1.]), ord=1)
tensor(1.)
>>> torch.linalg.norm(torch.tensor([-1.]), ord=3)
tensor(1.)
>>> torch.linalg.norm(torch.tensor([-1.]), ord=1.5)
tensor(1.)
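For reference, Julia's `LinearAlgebra.norm` agrees with `torch.linalg.norm` here, taking absolute values for any `p`:

julia> using LinearAlgebra

julia> norm([-1.0], 1), norm([-1.0], 3), norm([-1.0], 1.5)
(1.0, 1.0, 1.0)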
Note that there are other choices which would agree with this one when x is positive:
- Could take `norm(xs, p)`, as the `abs` does nothing, except that at present you'll get an error. The NNlib function (and the PyTorch one) do allow negative x when p=1, and this would change the answer.
- Could take `y = sum(x -> sign(x) * abs(x)^p, xs); sign(y) * abs(y)^(1/p)` (a sign-preserving variant; see the sketch below).
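A minimal sketch of that second, sign-preserving option (illustrative only, not this PR's code):

function signed_pnorm(xs, p::Real)
    # sum of sign-preserving powers; negative entries stay negative
    y = sum(x -> sign(x) * abs(x)^p, xs)
    # invert the power, again preserving sign, so fractional p stays real
    return sign(y) * abs(y)^(1/p)
end

signed_pnorm([-1.0, -2.0, 3.0], 3)   # signed cubes sum to 18; cbrt(18) ≈ 2.6207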
Oh, yes. I will update this more carefully.

Could do. `iseven(2.0)` and `iseven(2.01)` seem to behave well.
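For reference, on recent Julia versions (1.7+) `iseven` accepts non-integer `Real`s, returning `false` rather than throwing:

julia> iseven(2.0), iseven(2.01)
(true, false)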
And this.
My question is whether I can use try-catch to catch the `DomainError` thrown by `^(x, y)`, after skimming https://fluxml.ai/Zygote.jl/stable/limitations/#Try-catch-statements-1. `^(x, y)` allows even exponents and behaves well in the special-exponent cases.
I don't think try-catch is the way to go. This function should decide what it accepts, and check its input. Note for example that CUDA doesn't always follow Base in these things:
julia> cu([-1 0 1]) .^ 2.1f0
1×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
NaN 0.0 1.0
julia> [-1 0 1] .^ 2.1f0
ERROR: DomainError with -1.0:
Exponentiation yielding a complex result requires a complex argument.
Replace x^y with (x+0im)^y, complex(x)^y, or similar.
The existing check is fine, no try-catch required. The short-circuiting I was referring to was inserting an `iseven(p) &&` before the `any` check for negative elements.
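A hypothetical sketch of such a check (illustrative names, not the PR's actual code); an early `iseven(p)` return skips the O(n) scan whenever `p` is an even integer, since `x^p` is then defined for negative `x`:

function check_lpnorm_input(x::AbstractArray, p::Real)
    # even integer p: x^p is fine for negative x, skip the scan
    # (relies on Julia 1.7+ iseven accepting non-integer Reals)
    iseven(p) && return nothing
    # otherwise negative elements would hit DomainError / NaN in x^p
    any(<(0), x) && throw(DomainError(p,
        "input must be non-negative unless p is an even integer"))
    return nothing
end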
> I don't think try-catch is the way to go. This function should decide what it accepts, and check its input.

Makes sense! I learned a lot.
Force-pushed from e45eaec to 85a6eff ("… in Flux, rather than 'function ^(x, y)' in Base.Math")
PR Checklist

About #1431 Pooling Layers: `LPPool1d`, `LPPool2d`.