[Operator] support instance_norm #296
Updated from 042ab0a to c50e12a.
Thanks for your contribution! Please modify the code or respond to the comments.
Review done.
```
@@ -70,9 +88,15 @@ def layernorm_input_fn(shape, dtype, device):
        layernorm_input_fn,
        marks=pytest.mark.layer_norm,
    ),
    pytest.param(
        "instance_norm",
        torch.instance_norm,
```
Suggest using `torch.nn.functional.instance_norm` here.
I found that `torch.nn.functional.instance_norm` has a different API (argument order) from `torch.instance_norm`. The API implemented here is exactly that of `torch.instance_norm`.
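For reference, the two signatures differ mainly in argument order. To our understanding, `torch.nn.functional.instance_norm(input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05)` forwards to `torch.instance_norm(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled)`, filling `cudnn_enabled` from `torch.backends.cudnn.enabled`. The sketch below uses a hypothetical helper, `functional_to_aten_args`, that only reorders functional-style arguments into the positional order `torch.instance_norm` expects; it does not call PyTorch.

```python
from typing import Any, Optional, Tuple


def functional_to_aten_args(
    input: Any,
    running_mean: Optional[Any] = None,
    running_var: Optional[Any] = None,
    weight: Optional[Any] = None,
    bias: Optional[Any] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
    # Hypothetical default; the functional API reads torch.backends.cudnn.enabled instead.
    cudnn_enabled: bool = False,
) -> Tuple[Any, ...]:
    """Hypothetical helper: map F.instance_norm-style arguments to the
    positional order of torch.instance_norm."""
    # F.instance_norm order:     input, running_mean, running_var, weight, bias, ...
    # torch.instance_norm order: input, weight, bias, running_mean, running_var, ...
    return (input, weight, bias, running_mean, running_var,
            use_input_stats, momentum, eps, cudnn_enabled)
```

With PyTorch available, `torch.instance_norm(*functional_to_aten_args(x, weight=w, bias=b))` should produce the same result as `F.instance_norm(x, weight=w, bias=b)`, which is consistent with using `torch.instance_norm` directly as the reference in the benchmark parameters.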
```python
    eps,
    HAS_WEIGHT_BIAS=has_weight_bias,
)
if has_running_stats and use_input_stats:  # update running stats
```
Redundant condition.
Actually, no. `has_running_stats` means that `running_mean` and `running_var` are not None. `use_input_stats` means that `mean` and `rstd` are computed from the input; otherwise the externally supplied `running_mean` and `running_var` are used. The running stats should be updated only when `has_running_stats` and `use_input_stats` are both True.
`use_input_stats` is already checked at line 500, so there is no need to check it again.
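A minimal pure-Python sketch of the semantics discussed in this thread, assuming `torch.instance_norm` behaves as we understand it: each (n, c) instance is normalized with its own mean and biased variance, and when running stats are supplied they are updated in place with the batch average of the per-instance mean and unbiased variance. `instance_norm_ref` and `_normalize` are hypothetical names; this is not the PR's Triton kernel, and the affine weight/bias is omitted. Placing the update inside the `use_input_stats` path is what makes a second `use_input_stats` check unnecessary.

```python
import math
from typing import List, Optional

# (N, C, L): batch, channels, spatial dims flattened into one list per instance.
Tensor3D = List[List[List[float]]]


def _normalize(row: List[float], mean: float, var: float, eps: float) -> List[float]:
    """Normalize one (n, c) instance: (x - mean) * rstd."""
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd for v in row]


def instance_norm_ref(
    x: Tensor3D,
    running_mean: Optional[List[float]],
    running_var: Optional[List[float]],
    use_input_stats: bool,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor3D:
    """Hypothetical pure-Python reference; updates running stats in place."""
    has_running_stats = running_mean is not None and running_var is not None
    n, c = len(x), len(x[0])

    if not use_input_stats:
        if not has_running_stats:
            raise ValueError("use_input_stats=False requires running_mean and running_var")
        # Normalize every instance with the externally supplied statistics.
        return [[_normalize(x[b][ch], running_mean[ch], running_var[ch], eps)
                 for ch in range(c)] for b in range(n)]

    # From here on, use_input_stats is known to be True.
    out: Tensor3D = []
    mean_acc = [0.0] * c  # batch average of per-instance means
    var_acc = [0.0] * c   # batch average of per-instance unbiased variances
    for b in range(n):
        out_b = []
        for ch in range(c):
            row = x[b][ch]
            m = len(row)
            mean = sum(row) / m
            var = sum((v - mean) ** 2 for v in row) / m  # biased variance for normalization
            out_b.append(_normalize(row, mean, var, eps))
            if has_running_stats:
                if m < 2:
                    raise ValueError("need more than one value per instance to update running_var")
                mean_acc[ch] += mean / n
                var_acc[ch] += var * m / (m - 1) / n
        out.append(out_b)

    # Only has_running_stats needs checking here: use_input_stats was handled above.
    if has_running_stats:
        for ch in range(c):
            running_mean[ch] = (1 - momentum) * running_mean[ch] + momentum * mean_acc[ch]
            running_var[ch] = (1 - momentum) * running_var[ch] + momentum * var_acc[ch]
    return out
```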
You can pull the latest code to pass the CI.
OK, I will rebase onto the latest master soon.
Updated from a1bcd4a to 3876ce6.
Signed-off-by: ZQPei <[email protected]>
Updated from 3876ce6 to 06e0ba4.
LGTM
PR Category
Operator
Type of Change
New Feature
Description
Support instance norm (forward and backward), with and without running stats.
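The backward pass mentioned in the description can be checked against a small reference. Below is a sketch for one (n, c) instance with m elements, assuming `use_input_stats=True` and a per-channel affine weight `w` and bias `b`. With `xhat = (x - mean) * rstd` and `dxhat = dy * w`, the input gradient is `dx = rstd * (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat))`, while this instance contributes `sum(dy * xhat)` to `dw` and `sum(dy)` to `db` (the full channel gradients also sum over the batch). The function names are hypothetical. The path that normalizes with running stats is not covered; there the stats are constants, so `dx = dy * w * rstd`.

```python
import math
from typing import List, Tuple


def instance_norm_row_fwd(
    x: List[float], w: float, b: float, eps: float = 1e-5
) -> Tuple[List[float], List[float], float]:
    """Forward for one instance; returns y, xhat and rstd (saved for backward)."""
    m = len(x)
    mean = sum(x) / m
    var = sum((v - mean) ** 2 for v in x) / m  # biased variance
    rstd = 1.0 / math.sqrt(var + eps)
    xhat = [(v - mean) * rstd for v in x]
    y = [h * w + b for h in xhat]
    return y, xhat, rstd


def instance_norm_row_bwd(
    dy: List[float], xhat: List[float], rstd: float, w: float
) -> Tuple[List[float], float, float]:
    """Backward for one instance; returns dx and this instance's share of dw, db."""
    m = len(dy)
    dxhat = [g * w for g in dy]
    mean_dxhat = sum(dxhat) / m
    mean_dxhat_xhat = sum(d * h for d, h in zip(dxhat, xhat)) / m
    dx = [rstd * (d - mean_dxhat - h * mean_dxhat_xhat) for d, h in zip(dxhat, xhat)]
    dw = sum(g * h for g, h in zip(dy, xhat))
    db = sum(dy)
    return dx, dw, db
```

A central finite-difference check on a scalar loss such as `sum(y * g)` is a cheap way to validate a kernel's `dx`, `dw` and `db` against this reference.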
Issue
Progress
Performance
Environment:
A100 NVLink 80GB
PyTorch 1.13
CUDA 11.8