lib/Conversion/TorchToLinalg/Linear.cpp: 14 changes (9 additions, 5 deletions)

@@ -505,9 +505,11 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {

     // Broadcast the batch dimensions of both the matrices.
     Value broadcastedLhs, broadcastedRhs;
-    // TODO: Improve usage of static shape information.
-    SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
-                                        ShapedType::kDynamic);
+    SmallVector<int64_t> lhsTargetShape =
+        llvm::to_vector(llvm::map_range(lhsBroadcastToShape, [](Value v) {
+          return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+        }));
+
     auto lhsBroadcastType = RankedTensorType::get(
         lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
@@ -516,8 +518,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
-    SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
-                                        ShapedType::kDynamic);
+    SmallVector<int64_t> rhsTargetShape =
+        llvm::to_vector(llvm::map_range(rhsBroadcastToShape, [](Value v) {
+          return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+        }));
     auto rhsBroadcastType = RankedTensorType::get(
         rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
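The practical effect is that, when every broadcast target extent folds to a constant via getConstantIntValue, the intermediate broadcast tensors carry static extents instead of being fully dynamic. A minimal sketch of the shape handling, assuming all extents fold as in the new 4d test below (hypothetical function and value names, not taken from the patch):

// Sketch only: before this change the broadcast result would have carried
// kDynamic extents (e.g. tensor<?x?x?x?xf32>) and required a dynamic-to-static
// tensor.cast; with the constant extents recovered, the cast is a no-op.
func.func @broadcast_static_shape_sketch(%bcast: tensor<1x2x400x32xf32>) -> tensor<1x2x400x32xf32> {
  %cast = tensor.cast %bcast : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
  return %cast : tensor<1x2x400x32xf32>
}
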
test/Conversion/TorchToLinalg/basic.mlir: 64 changes (64 additions, 0 deletions)

@@ -43,6 +43,70 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch

// -----

// CHECK-LABEL: func.func @torch.aten.matmul.4d
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
Comment from a collaborator on lines +51 to +57 (a trimmed-down check along these lines is sketched after the diff):

This might be a bit flaky and adds a bunch of checks we don't actually care about in the logic of the conversion.

I'd just verify the relevant ops. E.g. check the collapse, expand, and linalg batch mm ops. The specific values for fill ops etc. aren't super important to check here.

// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_13:.*]] = arith.constant 400 : index
// CHECK: %[[VAL_14:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_15:.*]] = arith.constant 32 : index
// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_17:.*]] = arith.constant 32 : index
// CHECK: %[[VAL_18:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_19:.*]] = arith.constant 400 : index
// CHECK: %[[VAL_20:.*]] = arith.constant 32 : i64
// CHECK: %[[VAL_21:.*]] = arith.constant 32 : i64
// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64
// CHECK: cf.assert %[[VAL_22]], "mismatching contracting dimension"
// CHECK: %[[VAL_23:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_24:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_25:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_26:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_27:.*]] = arith.constant 400 : i64
// CHECK: %[[VAL_28:.*]] = arith.constant 32 : i64
// CHECK: %[[VAL_29:.*]] = arith.constant 32 : i64
// CHECK: %[[VAL_30:.*]] = arith.constant 400 : i64
// CHECK: %[[VAL_31:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_32:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_33:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32>
// CHECK: %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
// CHECK: %[[VAL_36:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32>
// CHECK: %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
// CHECK: %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
// CHECK: %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
// CHECK: %[[VAL_43:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32>
// CHECK: %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
// CHECK: %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
// CHECK: %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
// CHECK: %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32>
// CHECK: %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32>
// CHECK: return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32>
// CHECK: }
func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
%0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
return %0 : !torch.vtensor<[1,2,400,400],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
// CHECK-NOT: assert
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
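Picking up the review suggestion above about verifying only the relevant ops, a trimmed-down set of checks for the 4d test might look roughly like the following sketch (capture names are illustrative; this is not part of the committed PR):

// CHECK-LABEL: func.func @torch.aten.matmul.4d
// CHECK: %[[LHS:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
// CHECK: %[[RHS:.*]] = tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
// CHECK: %[[MM:.*]] = linalg.batch_matmul ins(%[[LHS]], %[[RHS]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%{{.*}} : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
// CHECK: tensor.expand_shape %[[MM]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>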