From ad1100c41627d45171fa529c2cb2c22477989159 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 13 Sep 2025 12:46:10 -0300 Subject: [PATCH] Improve error messages and error handling for `bmm`. --- test/test_ops_error_message.py | 66 ++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 23 ++++-- torch_xla/csrc/tensor_methods.cpp | 125 +++++++++++++++++++----------- torch_xla/csrc/tensor_methods.h | 11 ++- 4 files changed, 167 insertions(+), 58 deletions(-) diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index bbb9f4b95b7..14ed559c2e6 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -179,3 +179,69 @@ def test(): callable=test, expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8).""" ) + + def test_bmm_raises_error_on_non_3D_tensor_input(self): + device = torch_xla.device() + a = torch.rand(2, 3, 4, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test_a(): + torch.bmm(a[0], b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test_a, + expect="""bmm(): expected `input` f32[3,4] (a 2D tensor), the 1st input tensor, to be a 3D tensor.""" + ) + + def test_b(): + torch.bmm(a, b[0]) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test_b, + expect="""bmm(): expected `mat2` f32[4,3] (a 2D tensor), the 2nd input tensor, to be a 3D tensor.""" + ) + + def test_bmm_raises_error_on_different_batch_dimension(self): + device = torch_xla.device() + a = torch.rand(4, 3, 4, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test(): + torch.bmm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""bmm(): expected the size of the batch dimension (i.e. dimension 0) of `input` f32[4,3,4] (batch dimension size: 4), the 1st input tensor, to be the same as the size of the batch dimension of `mat2` f32[2,4,3] (batch dimension size: 2), the 2nd input tensor.""" + ) + + def test_bmm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + a = torch.rand(2, 3, 8, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test(): + torch.bmm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""bmm(): cannot apply batch matrix-multiplication to `input` f32[2,3,8], the 1st input tensor, and to `mat2` f32[2,4,3], the 2nd input tensor. Expected the size of dimension 2 of `input` (8) to be equal the size of dimension 1 of `mat2` (4).""" + ) + + def test_baddbmm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + input = torch.rand(3, 3, device=device) + a = torch.rand(2, 3, 8, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test(): + torch.baddbmm(input, a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""baddbmm(): cannot apply batch matrix-multiplication to `batch1` f32[2,3,8], the 2nd input tensor, and to `batch2` f32[2,4,3], the 3rd input tensor. Expected the size of dimension 2 of `batch1` (8) to be equal the size of dimension 1 of `batch2` (4).""" + ) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c042d703aa3..75784e21232 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1253,11 +1253,16 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, const at::Scalar& beta, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch1, bridge::GetXlaTensor(batch1)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch2, bridge::GetXlaTensor(batch2)); - return bridge::AtenFromXlaTensor( + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch1, + bridge::GetXlaTensor(batch1)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch2, + bridge::GetXlaTensor(batch2)); + XLA_ASSIGN_OR_THROW( + absl_nonnull XLATensorPtr output, tensor_methods::baddbmm(xla_self, xla_batch1, xla_batch2, beta, alpha)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::bernoulli( @@ -1337,9 +1342,13 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self, at::Tensor XLANativeFunctions::bmm(const at::Tensor& self, const at::Tensor& mat2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2)); - return bridge::AtenFromXlaTensor(tensor_methods::bmm(xla_self, xla_mat2)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_mat2, + bridge::GetXlaTensor(mat2)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::bmm(xla_self, xla_mat2)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 21f4db59713..17069b96b06 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -168,6 +168,25 @@ struct MinMaxValues { torch::lazy::Value max; }; +struct InputInfo { + const XLATensorPtr& tensor; + std::string_view name; + int position; + + std::string PositionAsStr() const { + switch (position) { + case 1: + return "1st"; + case 2: + return "2nd"; + case 3: + return "3rd"; + default: + return absl::StrCat(position, "th"); + } + } +}; + torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, const xla::Shape& target_shape) { if (GetXlaShape(input).dimensions() == target_shape.dimensions()) { @@ -193,46 +212,6 @@ MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor, tensor->GetDevice())}; } -void CheckRank(const XLATensorPtr& t, int64_t expected_rank, - const std::string& tag, const std::string& arg_name, - int arg_number) { - int64_t actual_rank = t->shape().get().dimensions_size(); - XLA_CHECK_EQ(actual_rank, expected_rank) - << "Expected " << expected_rank << "-dimensional tensor, but got " - << actual_rank << "-dimensional tensor for " - << "argument #" << arg_number << " '" << arg_name << "'" - << " (while checking arguments for " << tag << ")"; -} - -template -void CheckShapeDimensions(const T& size) { - XLA_CHECK(std::all_of(size.begin(), size.end(), [](int64_t dim) { - return dim >= 0; - })) << "Dimensions cannot be negative numbers"; -} - -void CheckDimensionSize(const XLATensorPtr& t, int64_t dim, - int64_t expected_size, const std::string& tag, - const std::string& arg_name, int arg_number) { - int64_t dim_size = t->size(dim); - XLA_CHECK_EQ(t->size(dim), expected_size) - << "Expected tensor to have size " << expected_size << " at dimension " - << dim << ", but got size " << dim_size << " for " - << "argument #" << arg_number << " '" << arg_name << "'" - << " (while checking arguments for " << tag << ")"; -} - -void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1, - const XLATensorPtr& batch2) { - // Consistent with the checks in bmm_out_or_baddbmm_. - CheckRank(batch1, 3, tag, "batch1", 1); - CheckRank(batch2, 3, tag, "batch2", 2); - CheckDimensionSize(batch2, 0, /*batch_size=*/batch1->size(0), tag, "batch2", - 2); - CheckDimensionSize(batch2, 1, /*contraction_size=*/batch1->size(2), tag, - "batch2", 2); -} - std::vector GetExpandDimensions(const xla::Shape& shape, std::vector dimensions) { XLA_CHECK_GE(dimensions.size(), shape.dimensions_size()) << shape; @@ -506,6 +485,51 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1, return absl::OkStatus(); } +absl::Status CheckInputIs3DTensor(const std::string_view op, + const InputInfo& input) { + int64_t rank = input.tensor->shape().get().dimensions().size(); + if (rank != 3) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): expected `", input.name, "` ", + input.tensor->shape().get().ToString(), " (a ", rank, "D tensor), the ", + input.PositionAsStr(), " input tensor, to be a 3D tensor."))); + } + return absl::OkStatus(); +} + +absl::Status CheckBmmInputsAreValid(const std::string_view op, + const InputInfo& input, + const InputInfo& mat2) { + XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, input)); + XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, mat2)); + + if (input.tensor->size(0) != mat2.tensor->size(0)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, + "(): expected the size of the batch dimension (i.e. dimension 0) of `", + input.name, "` ", input.tensor->shape().get().ToString(), + " (batch dimension size: ", input.tensor->size(0), "), the ", + input.PositionAsStr(), + " input tensor, to be the same as the size of the batch dimension of `", + mat2.name, "` ", mat2.tensor->shape().get().ToString(), + " (batch dimension size: ", mat2.tensor->size(0), "), the ", + mat2.PositionAsStr(), " input tensor."))); + } + if (input.tensor->size(2) != mat2.tensor->size(1)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): cannot apply batch matrix-multiplication to `", input.name, + "` ", input.tensor->shape().get().ToString(), ", the ", + input.PositionAsStr(), " input tensor, and to `", mat2.name, "` ", + mat2.tensor->shape().get().ToString(), ", the ", mat2.PositionAsStr(), + " input tensor. Expected the size of dimension 2 of `", input.name, + "` (", input.tensor->size(2), + ") to be equal the size of dimension 1 of `", mat2.name, "` (", + mat2.tensor->size(1), ")."))); + } + + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1214,10 +1238,14 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop, count_include_pad)); } -XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1, - const XLATensorPtr& batch2, const at::Scalar& beta, - const at::Scalar& alpha) { - CheckBmmDimension(/*tag=*/"baddbmm", batch1, batch2); +absl::StatusOr baddbmm(const XLATensorPtr& input, + const XLATensorPtr& batch1, + const XLATensorPtr& batch2, + const at::Scalar& beta, + const at::Scalar& alpha) { + XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid( + "baddbmm", {batch1, /* name= */ "batch1", /* position= */ 2}, + {batch2, /* name= */ "batch2", /* position= */ 3})); torch::lazy::Value product_multiplier = XLAGraphExecutor::Get()->GetIrValueForScalar( alpha, batch1->shape().get().element_type(), batch1->GetDevice()); @@ -1267,9 +1295,12 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other) { input->GetIrValue(), other->GetIrValue())); } -XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2) { - CheckBmmDimension(/*tag=*/"bmm", batch1, batch2); - return matmul(batch1, batch2); +absl::StatusOr bmm(const XLATensorPtr& input, + const XLATensorPtr& mat2) { + XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid( + "bmm", {input, /* name= */ "input", /* position= */ 1}, + {mat2, /* name= */ "mat2", /* position= */ 2})); + return matmul(input, mat2); } std::vector broadcast_tensors( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index b25b423d49c..1a82cfcc303 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -272,9 +272,11 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop, std::vector padding, bool ceil_mode, bool count_include_pad); -XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1, - const XLATensorPtr& batch2, const at::Scalar& beta, - const at::Scalar& alpha); +absl::StatusOr baddbmm(const XLATensorPtr& input, + const XLATensorPtr& batch1, + const XLATensorPtr& batch2, + const at::Scalar& beta, + const at::Scalar& alpha); XLATensorPtr bernoulli(const XLATensorPtr& input, double probability); XLATensorPtr bernoulli(const XLATensorPtr& input); @@ -297,7 +299,8 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other); // Batch matrix multiplication. Both tensors must be 3D, the batch size must // match and the remaining two dimensions must be compatible for matrix // multiplication. -XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2); +absl::StatusOr bmm(const XLATensorPtr& input, + const XLATensorPtr& mat2); // Broadcasts the given tensors according to broadcasting semantics. std::vector broadcast_tensors(