Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions test/test_ops_error_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,69 @@ def test():
callable=test,
expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
)

def test_bmm_raises_error_on_non_3D_tensor_input(self):
  # Both operands start out 3D; each case below slices one of them down to
  # 2D so that exactly one argument violates the rank requirement.
  xla_device = torch_xla.device()
  lhs = torch.rand(2, 3, 4, device=xla_device)
  rhs = torch.rand(2, 4, 3, device=xla_device)

  # Case 1: the first argument (`input`) is 2D.
  def bmm_with_2d_input():
    torch.bmm(lhs[0], rhs)

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=bmm_with_2d_input,
      expect="""bmm(): expected `input` f32[3,4] (a 2D tensor), the 1st input tensor, to be a 3D tensor."""
  )

  # Case 2: the second argument (`mat2`) is 2D.
  def bmm_with_2d_mat2():
    torch.bmm(lhs, rhs[0])

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=bmm_with_2d_mat2,
      expect="""bmm(): expected `mat2` f32[4,3] (a 2D tensor), the 2nd input tensor, to be a 3D tensor."""
  )

def test_bmm_raises_error_on_different_batch_dimension(self):
  # The operands are both 3D but disagree on dimension 0 (4 vs. 2).
  xla_device = torch_xla.device()
  lhs = torch.rand(4, 3, 4, device=xla_device)
  rhs = torch.rand(2, 4, 3, device=xla_device)

  def bmm_with_mismatched_batch():
    torch.bmm(lhs, rhs)

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=bmm_with_mismatched_batch,
      expect="""bmm(): expected the size of the batch dimension (i.e. dimension 0) of `input` f32[4,3,4] (batch dimension size: 4), the 1st input tensor, to be the same as the size of the batch dimension of `mat2` f32[2,4,3] (batch dimension size: 2), the 2nd input tensor."""
  )

def test_bmm_raises_error_on_incompatible_shapes(self):
  # Batch sizes agree (2), but dimension 2 of the first operand (8) does not
  # match dimension 1 of the second operand (4).
  xla_device = torch_xla.device()
  lhs = torch.rand(2, 3, 8, device=xla_device)
  rhs = torch.rand(2, 4, 3, device=xla_device)

  def bmm_with_incompatible_shapes():
    torch.bmm(lhs, rhs)

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=bmm_with_incompatible_shapes,
      expect="""bmm(): cannot apply batch matrix-multiplication to `input` f32[2,3,8], the 1st input tensor, and to `mat2` f32[2,4,3], the 2nd input tensor. Expected the size of dimension 2 of `input` (8) to be equal the size of dimension 1 of `mat2` (4)."""
  )

def test_baddbmm_raises_error_on_incompatible_shapes(self):
  # Same shape mismatch as the bmm() case, but routed through baddbmm(), so
  # the offending operands are reported as the 2nd and 3rd inputs.
  xla_device = torch_xla.device()
  addend = torch.rand(3, 3, device=xla_device)
  lhs = torch.rand(2, 3, 8, device=xla_device)
  rhs = torch.rand(2, 4, 3, device=xla_device)

  def baddbmm_with_incompatible_shapes():
    torch.baddbmm(addend, lhs, rhs)

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=baddbmm_with_incompatible_shapes,
      expect="""baddbmm(): cannot apply batch matrix-multiplication to `batch1` f32[2,3,8], the 2nd input tensor, and to `batch2` f32[2,4,3], the 3rd input tensor. Expected the size of dimension 2 of `batch1` (8) to be equal the size of dimension 1 of `batch2` (4)."""
  )
23 changes: 16 additions & 7 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,11 +1253,16 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
const at::Scalar& beta,
const at::Scalar& alpha) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch1, bridge::GetXlaTensor(batch1));
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch2, bridge::GetXlaTensor(batch2));
return bridge::AtenFromXlaTensor(
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch1,
bridge::GetXlaTensor(batch1));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch2,
bridge::GetXlaTensor(batch2));
XLA_ASSIGN_OR_THROW(
absl_nonnull XLATensorPtr output,
tensor_methods::baddbmm(xla_self, xla_batch1, xla_batch2, beta, alpha));
return bridge::AtenFromXlaTensor(std::move(output));
}

at::Tensor XLANativeFunctions::bernoulli(
Expand Down Expand Up @@ -1337,9 +1342,13 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self,
at::Tensor XLANativeFunctions::bmm(const at::Tensor& self,
const at::Tensor& mat2) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2));
return bridge::AtenFromXlaTensor(tensor_methods::bmm(xla_self, xla_mat2));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_mat2,
bridge::GetXlaTensor(mat2));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
tensor_methods::bmm(xla_self, xla_mat2));
return bridge::AtenFromXlaTensor(std::move(output));
}

at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors,
Expand Down
125 changes: 78 additions & 47 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,25 @@ struct MinMaxValues {
torch::lazy::Value max;
};

// Bundles a tensor argument with the information needed to mention it in an
// error message: the argument name and its position among the operation's
// input tensors.
//
// NOTE: `tensor` is stored as a reference, so an InputInfo must not outlive
// the XLATensorPtr it was created from.
struct InputInfo {
  const XLATensorPtr& tensor;
  std::string_view name;
  int position;

  // Returns `position` followed by its English ordinal suffix, e.g. "1st",
  // "2nd", "3rd", "4th", "11th", "21st", "112th".
  //
  // Numbers whose last two digits are 11, 12 or 13 always take "th";
  // otherwise the last digit selects "st" / "nd" / "rd", with "th" as the
  // fallback. Zero and negative positions never match a special case and
  // therefore also get "th".
  std::string PositionAsStr() const {
    const int last_two_digits = position % 100;
    const bool is_teen = last_two_digits >= 11 && last_two_digits <= 13;

    std::string_view suffix = "th";
    if (!is_teen) {
      switch (position % 10) {
        case 1:
          suffix = "st";
          break;
        case 2:
          suffix = "nd";
          break;
        case 3:
          suffix = "rd";
          break;
        default:
          break;
      }
    }
    return absl::StrCat(position, suffix);
  }
};

torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
const xla::Shape& target_shape) {
if (GetXlaShape(input).dimensions() == target_shape.dimensions()) {
Expand All @@ -193,46 +212,6 @@ MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor,
tensor->GetDevice())};
}

// Checks that tensor `t` has exactly `expected_rank` dimensions.
//
// The rank is read from `t->shape().get().dimensions_size()` and compared
// with XLA_CHECK_EQ. On mismatch the streamed message reports the expected
// and actual rank, the argument number and name (`arg_number`, `arg_name`),
// and the operation being checked (`tag`).
// NOTE(review): how a failed XLA_CHECK_EQ surfaces (exception vs. abort) is
// defined by the macro, which is not visible here.
void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
const std::string& tag, const std::string& arg_name,
int arg_number) {
int64_t actual_rank = t->shape().get().dimensions_size();
XLA_CHECK_EQ(actual_rank, expected_rank)
<< "Expected " << expected_rank << "-dimensional tensor, but got "
<< actual_rank << "-dimensional tensor for "
<< "argument #" << arg_number << " '" << arg_name << "'"
<< " (while checking arguments for " << tag << ")";
}

// Checks that no element of `size` is negative.
//
// `T` must be iterable via begin()/end() with elements convertible to
// int64_t. The check passes trivially for an empty range, since std::all_of
// returns true when there are no elements.
template <typename T>
void CheckShapeDimensions(const T& size) {
XLA_CHECK(std::all_of(size.begin(), size.end(), [](int64_t dim) {
return dim >= 0;
})) << "Dimensions cannot be negative numbers";
}

// Checks that dimension `dim` of tensor `t` has size `expected_size`.
//
// On mismatch the streamed message reports the expected size, the dimension
// index, the actual size, the argument number and name (`arg_number`,
// `arg_name`), and the operation being checked (`tag`).
void CheckDimensionSize(const XLATensorPtr& t, int64_t dim,
int64_t expected_size, const std::string& tag,
const std::string& arg_name, int arg_number) {
// Cached for the error message; the check below reads t->size(dim) again.
int64_t dim_size = t->size(dim);
XLA_CHECK_EQ(t->size(dim), expected_size)
<< "Expected tensor to have size " << expected_size << " at dimension "
<< dim << ", but got size " << dim_size << " for "
<< "argument #" << arg_number << " '" << arg_name << "'"
<< " (while checking arguments for " << tag << ")";
}

// Validates the operands of a batch matrix-multiplication for operation
// `tag`: both must be 3D, `batch2` must have the same dimension-0 size as
// `batch1`, and dimension 1 of `batch2` must equal dimension 2 of `batch1`.
void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1,
                       const XLATensorPtr& batch2) {
  // Consistent with the checks in bmm_out_or_baddbmm_.
  // Ranks are verified before any dimension size is read.
  CheckRank(batch1, /*expected_rank=*/3, tag, "batch1", /*arg_number=*/1);
  CheckRank(batch2, /*expected_rank=*/3, tag, "batch2", /*arg_number=*/2);

  const int64_t batch_size = batch1->size(0);
  CheckDimensionSize(batch2, /*dim=*/0, batch_size, tag, "batch2",
                     /*arg_number=*/2);

  const int64_t contraction_size = batch1->size(2);
  CheckDimensionSize(batch2, /*dim=*/1, contraction_size, tag, "batch2",
                     /*arg_number=*/2);
}

std::vector<int64_t> GetExpandDimensions(const xla::Shape& shape,
std::vector<int64_t> dimensions) {
XLA_CHECK_GE(dimensions.size(), shape.dimensions_size()) << shape;
Expand Down Expand Up @@ -506,6 +485,51 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
return absl::OkStatus();
}

// Returns OK when `input.tensor` has exactly 3 dimensions. Otherwise returns
// an InvalidArgument status whose message names the operation `op`, the
// argument name, its shape, its actual rank, and its position.
absl::Status CheckInputIs3DTensor(const std::string_view op,
                                  const InputInfo& input) {
  // Keep the value returned by shape() alive while it is inspected.
  auto shape_holder = input.tensor->shape();
  const int64_t actual_rank = shape_holder.get().dimensions().size();
  if (actual_rank == 3) {
    return absl::OkStatus();
  }

  const std::string message = absl::StrCat(
      op, "(): expected `", input.name, "` ", shape_holder.get().ToString(),
      " (a ", actual_rank, "D tensor), the ", input.PositionAsStr(),
      " input tensor, to be a 3D tensor.");
  return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
}

// Validates the two operands of a batch matrix-multiplication performed by
// operation `op`. The checks run in this order, returning the first failure:
//   1. `input` is a 3D tensor;
//   2. `mat2` is a 3D tensor;
//   3. both have the same size at dimension 0 (the batch dimension);
//   4. dimension 2 of `input` equals dimension 1 of `mat2`.
// Returns OK when every check passes, and an InvalidArgument status
// describing the offending operand(s) otherwise.
absl::Status CheckBmmInputsAreValid(const std::string_view op,
                                    const InputInfo& input,
                                    const InputInfo& mat2) {
  XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, input));
  XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, mat2));

  // Batch dimension: dimension 0 of both operands must agree.
  const int64_t input_batch_size = input.tensor->size(0);
  const int64_t mat2_batch_size = mat2.tensor->size(0);
  if (input_batch_size != mat2_batch_size) {
    const std::string message = absl::StrCat(
        op,
        "(): expected the size of the batch dimension (i.e. dimension 0) of `",
        input.name, "` ", input.tensor->shape().get().ToString(),
        " (batch dimension size: ", input_batch_size, "), the ",
        input.PositionAsStr(),
        " input tensor, to be the same as the size of the batch dimension of `",
        mat2.name, "` ", mat2.tensor->shape().get().ToString(),
        " (batch dimension size: ", mat2_batch_size, "), the ",
        mat2.PositionAsStr(), " input tensor.");
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
  }

  // Contraction dimension: columns of `input` must match rows of `mat2`.
  const int64_t input_contraction_size = input.tensor->size(2);
  const int64_t mat2_contraction_size = mat2.tensor->size(1);
  if (input_contraction_size != mat2_contraction_size) {
    const std::string message = absl::StrCat(
        op, "(): cannot apply batch matrix-multiplication to `", input.name,
        "` ", input.tensor->shape().get().ToString(), ", the ",
        input.PositionAsStr(), " input tensor, and to `", mat2.name, "` ",
        mat2.tensor->shape().get().ToString(), ", the ", mat2.PositionAsStr(),
        " input tensor. Expected the size of dimension 2 of `", input.name,
        "` (", input_contraction_size,
        ") to be equal the size of dimension 1 of `", mat2.name, "` (",
        mat2_contraction_size, ").");
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
  }

  return absl::OkStatus();
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1214,10 +1238,14 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
count_include_pad));
}

XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
const XLATensorPtr& batch2, const at::Scalar& beta,
const at::Scalar& alpha) {
CheckBmmDimension(/*tag=*/"baddbmm", batch1, batch2);
absl::StatusOr<absl_nonnull XLATensorPtr> baddbmm(const XLATensorPtr& input,
const XLATensorPtr& batch1,
const XLATensorPtr& batch2,
const at::Scalar& beta,
const at::Scalar& alpha) {
XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid(
"baddbmm", {batch1, /* name= */ "batch1", /* position= */ 2},
{batch2, /* name= */ "batch2", /* position= */ 3}));
torch::lazy::Value product_multiplier =
XLAGraphExecutor::Get()->GetIrValueForScalar(
alpha, batch1->shape().get().element_type(), batch1->GetDevice());
Expand Down Expand Up @@ -1267,9 +1295,12 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other) {
input->GetIrValue(), other->GetIrValue()));
}

XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2) {
CheckBmmDimension(/*tag=*/"bmm", batch1, batch2);
return matmul(batch1, batch2);
// Batch matrix-multiplies `input` by `mat2`. The operands are validated
// first; if validation fails, its error status is returned instead of a
// tensor. Otherwise the result of matmul(input, mat2) is returned.
absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
                                              const XLATensorPtr& mat2) {
  // Describe each operand (name and position) for error messages.
  const InputInfo input_info{input, /* name= */ "input", /* position= */ 1};
  const InputInfo mat2_info{mat2, /* name= */ "mat2", /* position= */ 2};
  XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid("bmm", input_info, mat2_info));
  return matmul(input, mat2);
}

std::vector<XLATensorPtr> broadcast_tensors(
Expand Down
11 changes: 7 additions & 4 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,11 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
std::vector<int64_t> padding, bool ceil_mode,
bool count_include_pad);

XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
const XLATensorPtr& batch2, const at::Scalar& beta,
const at::Scalar& alpha);
absl::StatusOr<absl_nonnull XLATensorPtr> baddbmm(const XLATensorPtr& input,
const XLATensorPtr& batch1,
const XLATensorPtr& batch2,
const at::Scalar& beta,
const at::Scalar& alpha);

XLATensorPtr bernoulli(const XLATensorPtr& input, double probability);
XLATensorPtr bernoulli(const XLATensorPtr& input);
Expand All @@ -297,7 +299,8 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other);
// Batch matrix multiplication. Both tensors must be 3D, the batch size must
// match and the remaining two dimensions must be compatible for matrix
// multiplication.
XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2);
absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
const XLATensorPtr& mat2);

// Broadcasts the given tensors according to broadcasting semantics.
std::vector<XLATensorPtr> broadcast_tensors(
Expand Down