[Operator] Add diagonal backward #329
Conversation
@StrongSpoon Hi, I have two questions about my implementation below. It is simple and just uses torch APIs:

```python
def backward(ctx, out_grad):
    logging.debug("GEMS DIAGONAL BACKWARD")
    (inp,) = ctx.saved_tensors
    grad_input = torch.zeros_like(inp)
    diag = torch.diagonal(grad_input, ctx.offset, ctx.dim1, ctx.dim2)
    diag.copy_(out_grad)
    return grad_input, None, None, None
```

Thank you very much!
@StrongSpoon A gentle reminder.
Hi awayzjj, we used to implement an operator with both forward and backward as a subclass of torch.autograd.Function and register it to the forward interface under the AutogradCUDA key. But recently we found that AutogradCUDA does not work perfectly with torch.compile. As a solution, we now recommend implementing the forward function and the backward function separately and registering each of them into the aten library.
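For illustration, a minimal sketch of that registration pattern; the function body is only a placeholder and the exact registration code in this repository may differ:

```python
import torch

# Register the backward as its own CUDA implementation of the existing
# aten::diagonal_backward op, instead of bundling forward and backward
# inside a torch.autograd.Function keyed on AutogradCUDA.
lib = torch.library.Library("aten", "IMPL")


def diagonal_backward(grad_output, input_sizes, offset, dim1, dim2):
    # Placeholder body; the real op would launch a Triton kernel here.
    grad_input = torch.zeros(
        input_sizes, dtype=grad_output.dtype, device=grad_output.device
    )
    torch.diagonal(grad_input, offset, dim1, dim2).copy_(grad_output)
    return grad_input


lib.impl("diagonal_backward", diagonal_backward, "CUDA")
```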
Likewise, there is a definition of the diagonal_backward interface, which is expected to be reimplemented by developers.
Besides, we require developers to implement the function by writing a Triton kernel rather than calling torch APIs. If you are unsure about the format, please refer to our source code.
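For reference only, here is a hypothetical sketch of a Triton-based diagonal backward for the contiguous 2-D, non-negative-offset case; the kernel name, signature, and launch parameters are assumptions, not the code that was merged:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def diag_bwd_kernel(
    grad_out_ptr,  # 1-D tensor of diagonal gradients
    grad_in_ptr,   # 2-D grad_input buffer, pre-filled with zeros
    diag_len,      # number of diagonal elements
    stride0,       # row stride of grad_input
    stride1,       # column stride of grad_input
    offset,        # diagonal offset, assumed >= 0 here
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < diag_len
    val = tl.load(grad_out_ptr + idx, mask=mask)
    # element (i, i + offset) of grad_input receives grad_output[i]
    tl.store(grad_in_ptr + idx * stride0 + (idx + offset) * stride1, val, mask=mask)


def diagonal_backward_2d(grad_output, input_sizes, offset=0):
    grad_output = grad_output.contiguous()
    grad_input = torch.zeros(
        input_sizes, dtype=grad_output.dtype, device=grad_output.device
    )
    diag_len = grad_output.numel()
    grid = (triton.cdiv(diag_len, 1024),)
    diag_bwd_kernel[grid](
        grad_output, grad_input, diag_len,
        grad_input.stride(0), grad_input.stride(1), offset,
        BLOCK=1024,
    )
    return grad_input
```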
Force-pushed from b66ae57 to 2851a03.
Please review my PR, thanks!
please provide the benchmark results.
src/flag_gems/ops/diagonal.py (outdated)
```python
    return grad_input


def diagonal_backward(grad_output, input_sizes, offset, dim1, dim2):
```
it's okay to fuse backward and diagonal_backward into one.
```python
        input_sizes, dtype=grad_output.dtype, device=grad_output.device
    )
    diag = torch.diagonal(grad_input, offset, dim1, dim2)
    copy_func.instantiate(grad_output.ndim)(grad_output, out0=diag)
```
Since torch.zeros also launches a kernel, there are actually two kernels here. I wonder whether it's feasible to initialize grad_input as an empty tensor and do the assignment in a single kernel function.
Got it, I'll give it a try.
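To make the suggestion concrete, a hypothetical single-kernel sketch, assuming a contiguous 2-D grad_input and a non-negative offset (this is not the code that was eventually merged): the kernel writes every element of grad_input in one pass, so the separate torch.zeros launch is no longer needed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_diag_bwd_kernel(
    grad_out_ptr, grad_in_ptr, numel, n_cols, offset, BLOCK: tl.constexpr
):
    # One pass over every element of grad_input: store 0 off the diagonal
    # and the matching grad_output value on the diagonal.
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < numel
    row = idx // n_cols
    col = idx % n_cols
    on_diag = col == row + offset
    # val is grad_output[row] on the diagonal and 0 elsewhere
    val = tl.load(grad_out_ptr + row, mask=mask & on_diag, other=0.0)
    tl.store(grad_in_ptr + idx, val, mask=mask)


def fused_diagonal_backward_2d(grad_output, input_sizes, offset=0):
    # grad_input can start uninitialized because the kernel writes every element
    grad_input = torch.empty(
        input_sizes, dtype=grad_output.dtype, device=grad_output.device
    )
    numel = grad_input.numel()
    grid = (triton.cdiv(numel, 1024),)
    fused_diag_bwd_kernel[grid](
        grad_output.contiguous(), grad_input, numel, input_sizes[1], offset,
        BLOCK=1024,
    )
    return grad_input
```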
tests/test_special_ops.py (outdated)
```python
    res_out = to_reference(res_out)
    res_in_grad = to_reference(res_in_grad)
    gems_assert_equal(res_out, ref_out)
    gems_assert_close(res_in_grad, ref_in_grad, dtype)
```
Why not require them to be equal? I thought the backward function doesn't change the values.
Yes, I fixed it.
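For clarity, the fix amounts to checking the gradient with exact equality, since the backward only copies grad_output into the diagonal without any arithmetic. A sketch using the test's existing helpers:

```python
gems_assert_equal(res_out, ref_out)
gems_assert_equal(res_in_grad, ref_in_grad)
```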
Force-pushed from e000f6e to 6515873.
Hi, @StrongSpoon
lgtm
* diagonal v0
* impl triton version
* fix code format
* fix --ref cpu failed
* use gems_assert_equal to validate res_in_grad
Issue
Close #314