
[Operator] Add diagonal backward #329

Merged
5 commits merged on Dec 5, 2024

Conversation

@awayzjj (Collaborator) commented Nov 26, 2024

PR Category

Type of Change

Description

Issue

Close #314

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

(two benchmark result screenshots attached)

@awayzjj (Collaborator, Author) commented Nov 26, 2024

@StrongSpoon Hi, I have 2 questions.

  1. How do I register the backward independently into the aten library? I could not find a reference in the repo, so for now I implemented a class with forward and backward functions as a draft.
  2. I implemented the backward function as below, and the UT passed:
    def backward(ctx, out_grad):
        logging.debug("GEMS DIAGONAL BACKWARD")
        (inp,) = ctx.saved_tensors
        grad_input = torch.zeros_like(inp)
        diag = torch.diagonal(grad_input, ctx.offset, ctx.dim1, ctx.dim2)
        diag.copy_(out_grad)
        return grad_input, None, None, None

It is simple: torch.diagonal gives a view into grad_input, and the values are copied from out_grad into that view. I wonder whether we need a Triton kernel here.

Thank you very much!
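
(For reference, a minimal sketch of the draft class described above. The backward is the one quoted in the comment; the forward shown here is an assumption inferred from the ctx fields the backward uses, and the names are illustrative only.)

import logging

import torch


class Diagonal(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, offset=0, dim1=0, dim2=1):
        logging.debug("GEMS DIAGONAL")
        ctx.save_for_backward(inp)
        # stash the arguments so backward can rebuild the diagonal view
        ctx.offset = offset
        ctx.dim1 = dim1
        ctx.dim2 = dim2
        # clone so the output does not alias the input; the real aten::diagonal
        # returns a view, this is simplified for the sketch
        return torch.diagonal(inp, offset, dim1, dim2).clone()

    @staticmethod
    def backward(ctx, out_grad):
        logging.debug("GEMS DIAGONAL BACKWARD")
        (inp,) = ctx.saved_tensors
        grad_input = torch.zeros_like(inp)
        diag = torch.diagonal(grad_input, ctx.offset, ctx.dim1, ctx.dim2)
        diag.copy_(out_grad)
        return grad_input, None, None, None


# usage: out = Diagonal.apply(x, 1, 0, 1)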

@awayzjj (Collaborator, Author) commented Nov 28, 2024

@StrongSpoon A gentle reminder.


@StrongSpoon (Collaborator) commented:

Hi awayzjj,

We used to implement an operator with both forward and backward as a subclass of torch.autograd.Function and register it to the forward interface under the AutogradCUDA key. But we recently found that AutogradCUDA does not work perfectly with torch.compile. As a solution, we now recommend implementing the forward and backward as separate functions and registering each of them into the aten library.
Take tanh as an example: its forward interface is defined in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml. In the previous practice, we implemented a class Tanh and registered it to the tanh interface. It is better to register the forward function to tanh and the backward function to tanh_backward, using CUDA as the dispatch key.

- func: tanh(Tensor self) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: tanh.out
  variants: function, method
  dispatch:
    QuantizedCPU: tanh_quantized_cpu
    MkldnnCPU: mkldnn_tanh
    SparseCPU, SparseCUDA: tanh_sparse
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh
  tags: [core, pointwise]
- func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
  python_module: nn
  structured_delegate: tanh_backward.grad_input
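
(For illustration, a minimal sketch of this registration pattern using torch.library; this is not FlagGems' actual code, and the tanh_forward / tanh_backward bodies below are plain-PyTorch placeholders standing in for Triton-backed implementations.)

import torch


def tanh_forward(self):
    # placeholder for a Triton-backed forward kernel
    e2x = torch.exp(2 * self)
    return (e2x - 1) / (e2x + 1)


def tanh_backward(grad_output, output):
    # placeholder for a Triton-backed backward kernel: d tanh(x)/dx = 1 - tanh(x)^2
    return grad_output * (1 - output * output)


# register each function to its own aten entry under the CUDA dispatch key
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh_forward, "CUDA")
aten_lib.impl("tanh_backward", tanh_backward, "CUDA")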

@StrongSpoon (Collaborator) commented:

In the same way, native_functions.yaml defines a diagonal_backward interface, which developers are expected to reimplement.

@StrongSpoon (Collaborator) commented:

Besides, we require developers to implement the function by writing a Triton kernel rather than composing torch APIs. If you are unsure about the format, please refer to our source code.
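
(For concreteness, a rough sketch of what a Triton-based diagonal_backward could look like, restricted to the 2-D case; the names and signature here are illustrative and not the code that ended up in this PR.)

import torch
import triton
import triton.language as tl


@triton.jit
def diag_fill_kernel(
    grad_in_ptr,             # (rows, cols) zero-filled grad_input
    grad_out_ptr,            # contiguous 1-D grad_output of length diag_len
    diag_len,
    row_start, col_start,    # position of the first diagonal element
    stride_row, stride_col,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < diag_len
    val = tl.load(grad_out_ptr + idx, mask=mask)
    ptrs = grad_in_ptr + (row_start + idx) * stride_row + (col_start + idx) * stride_col
    tl.store(ptrs, val, mask=mask)


def diagonal_backward_2d(grad_output, input_sizes, offset=0):
    rows, cols = input_sizes
    grad_input = torch.zeros(rows, cols, dtype=grad_output.dtype,
                             device=grad_output.device)
    diag_len = grad_output.numel()
    if diag_len > 0:
        grid = (triton.cdiv(diag_len, 1024),)
        diag_fill_kernel[grid](
            grad_input, grad_output.contiguous(), diag_len,
            max(-offset, 0), max(offset, 0),
            grad_input.stride(0), grad_input.stride(1),
            BLOCK=1024,
        )
    return grad_input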

@StrongSpoon StrongSpoon self-assigned this Dec 4, 2024
@awayzjj awayzjj requested a review from StrongSpoon December 5, 2024 01:31
@awayzjj (Collaborator, Author) commented Dec 5, 2024


Please review my PR, thanks!

@StrongSpoon (Collaborator) left a review:

please provide the benchmark results.

return grad_input


def diagonal_backward(grad_output, input_sizes, offset, dim1, dim2):
@StrongSpoon (Collaborator) commented on the code above:

it's okay to fuse backward and diagonal_backward into one.

input_sizes, dtype=grad_output.dtype, device=grad_output.device
)
diag = torch.diagonal(grad_input, offset, dim1, dim2)
copy_func.instantiate(grad_output.ndim)(grad_output, out0=diag)
@StrongSpoon (Collaborator) commented on the code above:

since torch.zeros also calls for a kernel, there exist two kernels indeed. I wonder if it's feasible to initialize grad_input as an empty tensor, and assign it in one kernel function.

@awayzjj (Collaborator, Author) replied:

Got it, I'll give it a try.
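
(One way the reviewer's single-kernel suggestion could be realized, sketched for the contiguous 2-D case with illustrative names, not the code that ended up in this PR: the kernel covers every element of grad_input, writing the matching grad_output value on the diagonal and zero everywhere else, so the separate torch.zeros launch is no longer needed.)

import torch
import triton
import triton.language as tl


@triton.jit
def diag_scatter_kernel(
    grad_in_ptr,        # contiguous (rows, cols) grad_input, written in full
    grad_out_ptr,       # contiguous 1-D grad_output
    rows, cols, offset,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    flat = pid * BLOCK + tl.arange(0, BLOCK)
    mask = flat < rows * cols
    row = flat // cols
    col = flat % cols
    on_diag = mask & (col - row == offset)
    # on the diagonal the matching grad_output index is min(row, col)
    diag_idx = tl.minimum(row, col)
    # the masked load yields 0.0 off the diagonal, which is exactly the fill value
    val = tl.load(grad_out_ptr + diag_idx, mask=on_diag, other=0.0)
    tl.store(grad_in_ptr + flat, val, mask=mask)


def diagonal_backward_fused(grad_output, input_sizes, offset=0):
    rows, cols = input_sizes
    grad_input = torch.empty(rows, cols, dtype=grad_output.dtype,
                             device=grad_output.device)
    if rows * cols > 0:
        grid = (triton.cdiv(rows * cols, 1024),)
        diag_scatter_kernel[grid](grad_input, grad_output.contiguous(),
                                  rows, cols, offset, BLOCK=1024)
    return grad_input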

res_out = to_reference(res_out)
res_in_grad = to_reference(res_in_grad)
gems_assert_equal(res_out, ref_out)
gems_assert_close(res_in_grad, ref_in_grad, dtype)
@StrongSpoon (Collaborator) commented on the test code above:

Why not require them to be equal? I thought the backward function doesn't change the values.

@awayzjj (Collaborator, Author) replied:

Yes, I fixed it.

@awayzjj (Collaborator, Author) commented Dec 5, 2024

Hi, @StrongSpoon
CI failed, but the failing UT is not relevant to my PR (I can reproduce it after running the UT several times on the latest master branch).
(screenshot of the failing CI check)

@StrongSpoon (Collaborator) left a review:

lgtm

@StrongSpoon merged commit 923567e into FlagOpen:master on Dec 5, 2024
8 of 9 checks passed
StrongSpoon pushed a commit that referenced this pull request Dec 12, 2024
* diagonal v0

* impl triton version

* fix code format

* fix --ref cpu failed

* use gems_assert_equal to validate res_in_grad
DuanYaQi pushed a commit that referenced this pull request Dec 17, 2024 (same commit messages as above)
Gxiandy pushed a commit to Gxiandy/FlagGems that referenced this pull request Jan 12, 2025 (same commit messages as above)
Successfully merging this pull request may close these issues:

Code Contribution: 【Lv1】【Operator Development】diagonal_backward (#314)