1111#include < iterator>
1212
1313#include " absl/log/absl_check.h"
14+ #include " absl/status/status.h"
1415#include " absl/strings/str_cat.h"
1516#include " absl/strings/str_split.h"
1617#include " torch_xla/csrc/LazyIr.h"
@@ -453,14 +454,14 @@ absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
453454 return absl::OkStatus ();
454455}
455456
456- // Checks that all index dimensions are smaller or equal to those of input,
457- // except on dimension canonical_dim.
458- absl::Status CheckGatherDimensionsAreCompatible (const XLATensorPtr& input,
459- const XLATensorPtr& index,
460- int64_t canonical_dim) {
457+ // Checks that all index dimension sizes are smaller or equal to those of
458+ // input, except on dimension canonical_dim.
459+ absl::Status CheckGatherSizesAreCompatible (const XLATensorPtr& input,
460+ const XLATensorPtr& index,
461+ int64_t canonical_dim) {
461462 // Dimensions that fail the "smaller or equal" condition.
462463 std::vector<int64_t > bad_dims;
463- for (int64_t dim = 0 ; dim < input->shape ().get ().dimensions_size (); dim++) {
464+ for (int64_t dim = 0 ; dim < input->shape ().get ().dimensions (). size (); dim++) {
464465 if (dim != canonical_dim && input->size (dim) < index->size (dim)) {
465466 bad_dims.push_back (dim);
466467 }
@@ -478,6 +479,33 @@ absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
478479 return absl::OkStatus ();
479480}
480481
482+ absl::Status CheckMMInputIsMatrix (const XLATensorPtr& mat,
483+ const std::string_view arg) {
484+ xla::Shape shape = mat->shape ();
485+ if (shape.dimensions ().size () != 2 ) {
486+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (
487+ absl::StrCat (" mm(): expected the " , arg, " input tensor " ,
488+ shape.ToString (), " to be a matrix (i.e. a 2D tensor)." )));
489+ }
490+ return absl::OkStatus ();
491+ }
492+
493+ absl::Status CheckMMMatrixSizesAreCompatible (const XLATensorPtr& mat1,
494+ const XLATensorPtr& mat2) {
495+ xla::Shape shape1 = mat1->shape ();
496+ xla::Shape shape2 = mat2->shape ();
497+ if (shape1.dimensions (1 ) != shape2.dimensions (0 )) {
498+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
499+ " mm(): cannot matrix-multiply tensors " , shape1.ToString (), " and " ,
500+ shape2.ToString (),
501+ " . Expected the size of dimension 1 of the first input tensor (" ,
502+ shape1.dimensions (1 ),
503+ " ) to be equal the size of dimension 0 of the second input tensor (" ,
504+ shape2.dimensions (0 ), " )." )));
505+ }
506+ return absl::OkStatus ();
507+ }
508+
481509} // namespace
482510
483511// ////////////////////////////////////////////////////////////////////////////
@@ -1844,7 +1872,7 @@ absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
18441872 dim, input->shape ().get ().dimensions_size ());
18451873 XLA_RETURN_IF_ERROR (CheckGatherRanksAreEqual (input, index));
18461874 XLA_RETURN_IF_ERROR (
1847- CheckGatherDimensionsAreCompatible (input, index, canonical_dim));
1875+ CheckGatherSizesAreCompatible (input, index, canonical_dim));
18481876 return input->CreateFrom (torch_xla::MakeNode<Gather>(
18491877 input->GetIrValue (), canonical_dim, index->GetIrValue ()));
18501878}
@@ -2349,7 +2377,11 @@ XLATensorPtr mish(const XLATensorPtr& input) {
23492377 tensor_ops::Softplus (input, 1 , 20 )->GetIrValue ()));
23502378}
23512379
2352- XLATensorPtr mm (const XLATensorPtr& input, const XLATensorPtr& weight) {
2380+ absl::StatusOr<XLATensorPtr> mm (const XLATensorPtr& input,
2381+ const XLATensorPtr& weight) {
2382+ XLA_RETURN_IF_ERROR (CheckMMInputIsMatrix (input, " first" ));
2383+ XLA_RETURN_IF_ERROR (CheckMMInputIsMatrix (weight, " second" ));
2384+ XLA_RETURN_IF_ERROR (CheckMMMatrixSizesAreCompatible (input, weight));
23532385 return input->CreateFrom (Dot (input->GetIrValue (), weight->GetIrValue ()));
23542386}
23552387
0 commit comments