
Commit 899d3d3

ahmadsharif1 authored and pytorchmergebot committed
Don't call sum() on a tensor that is not summable in layer_norm (pytorch#156600)
Don't call `sum()` on a tensor that is default-constructed. Previously we could call `sum()` on a default-constructed tensor, which led to an error like this:

```
Traceback (most recent call last):
  File "/home/ahmads/.conda/envs/pt3/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/home/ahmads/.conda/envs/pt3/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/home/ahmads/.conda/envs/pt3/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/home/ahmads/personal/pytorch/torch/testing/_internal/common_utils.py", line 3191, in wrapper
    method(*args, **kwargs)
  File "/home/ahmads/personal/pytorch/test/test_nn.py", line 7235, in test_layer_norm_backwards_eps
    ln_out_cuda.backward(grad_output_cuda)
  File "/home/ahmads/personal/pytorch/torch/_tensor.py", line 647, in backward
    torch.autograd.backward(
  File "/home/ahmads/personal/pytorch/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/home/ahmads/personal/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: tensor does not have a device
Exception raised from device_default at /home/ahmads/personal/pytorch/c10/core/TensorImpl.h:1265 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) from ??:0
#7 at::TensorBase::options() const from :0
#8 at::meta::resize_reduction(at::impl::MetaBase&, at::Tensor const&, c10::OptionalArrayRef<long>, bool, c10::ScalarType, bool) from :0
#9 at::meta::structured_sum_dim_IntList::meta(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from ??:0
#10 at::(anonymous namespace)::wrapper_CompositeExplicitAutogradNonFunctional_sum_dim_IntList(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from RegisterCompositeExplicitAutogradNonFunctional_0.cpp:0
#11 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>), &at::(anonymous namespace)::wrapper_CompositeExplicitAutogradNonFunctional_sum_dim_IntList>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType> > >, at::Tensor (at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from RegisterCompositeExplicitAutogradNonFunctional_0.cpp:0
#12 at::_ops::sum_dim_IntList::call(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>) from ??:0
#13 void at::native::(anonymous namespace)::LaunchGammaBetaBackwardCUDAKernel<float, float>(float const*, float const*, float const*, float const*, long, long, at::Tensor*, at::Tensor*, CUstream_st*) from ??:0
#14 void at::native::(anonymous namespace)::LayerNormBackwardKernelImplInternal<float>(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, at::Tensor*, at::Tensor*, at::Tensor*) from ??:0
#15 at::native::(anonymous namespace)::LayerNormBackwardKernelImpl(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, at::Tensor*, at::Tensor*, at::Tensor*) from ??:0
#16 at::native::layer_norm_backward_cuda(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::array<bool, 3ul>) from ??:0
#17 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA__native_layer_norm_backward(at::Tensor const&, at::Tensor const&, c10::ArrayRef<c10::SymInt>, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::array<bool, 3ul>) from RegisterCUDA_0.cpp:0
```

Now `sum(0)` is only called on tensors that are defined, and the `sum(0)` call and assignment are properly guarded.

Pull Request resolved: pytorch#156600
Approved by: https://github.com/eqy, https://github.com/ngimel
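Not part of the commit: a minimal standalone sketch of the CPU-vs-CUDA gradient comparison that the updated `test_layer_norm_backwards_eps` below performs, assuming a CUDA-enabled PyTorch build recent enough to accept the `bias=` argument to `nn.LayerNorm`. Before this fix, some of these CUDA backward calls could fail with the `RuntimeError` shown in the traceback above.

```python
# Hedged sketch (not from the commit): mirror the CPU-vs-CUDA LayerNorm backward
# comparison in the updated test, across elementwise_affine/bias combinations.
import itertools

import torch
import torch.nn as nn

assert torch.cuda.is_available(), "requires a CUDA-enabled PyTorch build"

dtype = torch.float
# A subset of the shapes the test sweeps, including a newly added large-M case.
shapes = [(1025, 33), (128 * 1024, 32)]
for elementwise_affine, bias in itertools.product([True, False], repeat=2):
    for m, n in shapes:
        x = torch.randn((m, n), dtype=dtype, requires_grad=True)
        grad_output = torch.rand_like(x)
        x_cuda = x.detach().clone().to("cuda").requires_grad_()

        ln = nn.LayerNorm(n, dtype=dtype,
                          elementwise_affine=elementwise_affine, bias=bias)
        ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype,
                               elementwise_affine=elementwise_affine, bias=bias)

        ln(x).backward(grad_output)
        # Before this fix, the CUDA backward could raise
        # "RuntimeError: tensor does not have a device" for some configurations.
        ln_cuda(x_cuda).backward(grad_output.to("cuda"))

        if elementwise_affine:
            torch.testing.assert_close(ln.weight.grad, ln_cuda.weight.grad.cpu(),
                                       rtol=1e-4, atol=1e-4)
            if bias:
                torch.testing.assert_close(ln.bias.grad, ln_cuda.bias.grad.cpu(),
                                           rtol=1e-5, atol=1e-4)
```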
1 parent 17eb649 · commit 899d3d3

File tree: 2 files changed (+27, −16 lines)

aten/src/ATen/native/cuda/layer_norm_kernel.cu

Lines changed: 6 additions & 2 deletions
```diff
@@ -884,8 +884,12 @@ void LaunchGammaBetaBackwardCUDAKernel(
     LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
       aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr);
 
-    *dgamma = dgamma_blocks.sum(0);
-    *dbeta = dbeta_blocks.sum(0);
+    if (dgamma_blocks.defined()) {
+      *dgamma = dgamma_blocks.sum(0);
+    }
+    if (dbeta_blocks.defined()) {
+      *dbeta = dbeta_blocks.sum(0);
+    }
   } else {
     // We are in the normal case where M is not that large.
     // We can change the tile shape (which is the last template parameter) in accordance with M.
```

test/test_nn.py

Lines changed: 21 additions & 14 deletions
```diff
@@ -7212,25 +7212,32 @@ def test_layer_norm_eps(self):
         ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
         self.assertEqual(ln.forward(x), torch.zeros_like(x))
 
+
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
     def test_layer_norm_backwards_eps(self):
         dtype = torch.float
         m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55),
                       (32, 32), (1024, 32), (1024, 1024),
-                      (33, 33), (1025, 33), (1025, 1025)]
-        for m, n in m_x_n_list:
-            x = torch.randn((m, n), dtype=dtype, requires_grad=True)
-            grad_output = torch.rand_like(x)
-            x_cuda = x.clone().detach().to("cuda").requires_grad_()
-            grad_output_cuda = grad_output.clone().detach().to("cuda")
-            ln = nn.LayerNorm(n, dtype=dtype)
-            ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype)
-            ln_out = ln(x)
-            ln_out_cuda = ln_cuda(x_cuda)
-            ln_out.backward(grad_output)
-            ln_out_cuda.backward(grad_output_cuda)
-            self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
-            self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
+                      (33, 33), (1025, 33), (1025, 1025),
+                      (128 * 1024, 32), (32, 128 * 1024)]
+        boolean = [True, False]
+        combinations = itertools.product(boolean, repeat=2)
+        for elementwise_affine, bias in combinations:
+            for m, n in m_x_n_list:
+                x = torch.randn((m, n), dtype=dtype, requires_grad=True)
+                grad_output = torch.rand_like(x)
+                x_cuda = x.clone().detach().to("cuda").requires_grad_()
+                grad_output_cuda = grad_output.clone().detach().to("cuda")
+                ln = nn.LayerNorm(n, dtype=dtype, elementwise_affine=elementwise_affine, bias=bias)
+                ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype, elementwise_affine=elementwise_affine, bias=bias)
+                ln_out = ln(x)
+                ln_out_cuda = ln_cuda(x_cuda)
+                ln_out.backward(grad_output)
+                ln_out_cuda.backward(grad_output_cuda)
+                if elementwise_affine:
+                    self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-4, atol=1e-4)
+                if bias and elementwise_affine:
+                    self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
 
     @largeTensorTest("40GB", device="cuda")
     def test_layer_norm_large_tensor(self):
```
