[MLIR][TORCH] Add support for enable_gqa flag in SDPA op (llvm#3950)
Signed-off-by: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 authored Feb 5, 2025
1 parent 7cea07c commit 25aa0c6
Showing 3 changed files with 153 additions and 5 deletions.
128 changes: 123 additions & 5 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -372,6 +372,54 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
  return SmallVector<Value>(topkOp.getResults());
}

static FailureOr<Value>
repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter,
                           Type resType, Value self, int64_t repeats,
                           int64_t dim) {
  Location loc = op->getLoc();
  auto context = op->getContext();
  auto selfTy = cast<BaseTensorType>(self.getType());

  int64_t inputRank = selfTy.getSizes().size();
  dim = toPositiveDim(dim, inputRank);
  Value dimValue =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
  Value dimValuePlusOne =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim + 1));

  auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
  if (failed(unsqueezedInfo))
    return rewriter.notifyMatchFailure(op,
                                       "cannot generate unsqueeze tensor op");
  self = *unsqueezedInfo;

  Value constMinusOne =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
  SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
  expandShapeValueList[dim + 1] =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(repeats));
  Value expandShapeList = rewriter.create<PrimListConstructOp>(
      loc, ListType::get(IntType::get(context)), expandShapeValueList);

  SmallVector<int64_t> expandShape(inputRank + 1);
  for (int64_t i = 0; i <= dim; i++) {
    expandShape[i] = selfTy.getSizes()[i];
  }
  expandShape[dim + 1] = repeats;
  for (int64_t i = dim + 1; i < inputRank; i++) {
    expandShape[i + 1] = selfTy.getSizes()[i];
  }

  BaseTensorType expandTy =
      rewriter.getType<ValueTensorType>(expandShape, selfTy.getOptionalDtype());
  Value expandSelf =
      rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, expandShapeList);

  Value result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
                                                  dimValue, dimValuePlusOne);
  return result;
}
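
The helper above performs an interleaved repeat along a single dimension by chaining unsqueeze, broadcast, and collapse. For intuition only, the same shape arithmetic can be sketched in a few lines of eager PyTorch; the function name here is invented for illustration (the lowering itself emits Torch-dialect ops, not eager calls), and the result agrees with torch.repeat_interleave:

import torch

def repeat_elements_for_dim(t: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
    # Mirror the lowering: insert a unit dim right after `dim`
    # (unsqueezeTensor), broadcast it to `repeats` (AtenBroadcastToOp),
    # then fold it back into `dim` (PrimsCollapseOp).
    dim = dim % t.ndim
    expanded = t.unsqueeze(dim + 1)
    sizes = list(expanded.shape)
    sizes[dim + 1] = repeats
    expanded = expanded.expand(sizes)
    out_sizes = list(t.shape)
    out_sizes[dim] *= repeats
    return expanded.reshape(out_sizes)

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_elements_for_dim(x, 2, dim=1),
                   torch.repeat_interleave(x, 2, dim=1))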

namespace {
template <typename AtenOpT>
class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
@@ -1651,6 +1699,65 @@ class ConvertAtenScaledDotProductAttentionOp
    : public OpConversionPattern<AtenScaledDotProductAttentionOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  static LogicalResult
  preProcessGroupQueryAttentionInput(AtenScaledDotProductAttentionOp op,
                                     ConversionPatternRewriter &rewriter,
                                     const TypeConverter *typeConverter,
                                     Value query, Value &key, Value &value) {
    auto queryTy = cast<ShapedType>(query.getType());
    auto valueTy = cast<ShapedType>(value.getType());
    auto keyTy = cast<ShapedType>(key.getType());

    int64_t rank = queryTy.getRank();

    int64_t qNumHeads = queryTy.getDimSize(rank - 3);
    int64_t kNumHeads = keyTy.getDimSize(rank - 3);
    int64_t vNumHeads = valueTy.getDimSize(rank - 3);

    if (llvm::any_of(llvm::ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads},
                     [](int64_t d) { return d == Torch::kUnknownSize; })) {
      return llvm::failure();
    }

    if (llvm::all_equal(
            llvm::ArrayRef<int64_t>{qNumHeads, kNumHeads, vNumHeads}))
      return llvm::success();

    if ((qNumHeads % kNumHeads) || (qNumHeads % vNumHeads))
      return llvm::failure();

    int64_t repeatKeyShape = qNumHeads / kNumHeads;
    int64_t repeatValueShape = qNumHeads / vNumHeads;

    Location loc = op.getLoc();
    FailureOr<Value> keyRepeated = repeatTensorElementsForDim(
        op.getOperation(), rewriter, /*resType=*/op.getQuery().getType(),
        op.getKey(),
        /*repeats=*/repeatKeyShape, /*dim=*/rank - 3);
    if (failed(keyRepeated))
      return rewriter.notifyMatchFailure(
          loc, "Failed to repeat the tensor elements for key.");

    FailureOr<Value> valueRepeated = repeatTensorElementsForDim(
        op.getOperation(), rewriter, /*resType=*/op.getQuery().getType(),
        op.getValue(),
        /*repeats=*/repeatValueShape, /*dim=*/rank - 3);
    if (failed(valueRepeated))
      return rewriter.notifyMatchFailure(
          loc, "Failed to repeat the tensor elements for value.");

    key = typeConverter->materializeTargetConversion(
        rewriter, loc,
        typeConverter->convertType(keyRepeated.value().getType()),
        keyRepeated.value());
    value = typeConverter->materializeTargetConversion(
        rewriter, loc,
        typeConverter->convertType(valueRepeated.value().getType()),
        valueRepeated.value());
    return success();
  }

  LogicalResult
  matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
@@ -1795,11 +1902,6 @@ class ConvertAtenScaledDotProductAttentionOp
        scaleFloat != 1.0)
      return rewriter.notifyMatchFailure(loc, "only default scale supported");
    }
    bool isGQAEnabled;
    if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
        isGQAEnabled)
      return rewriter.notifyMatchFailure(
          loc, "grouped query attention not supported");

    if (queryTy.getRank() != valueTy.getRank() ||
        queryTy.getRank() != keyTy.getRank())
@@ -1808,6 +1910,22 @@ class ConvertAtenScaledDotProductAttentionOp
    if (queryTy.getRank() < 3)
      return rewriter.notifyMatchFailure(op, "missing batch dimension");

    bool isGQAEnabled;
    if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)))
      return rewriter.notifyMatchFailure(
          loc, "Expected enable_gqa flag to be constant bool");

    // When the `enable_gqa` flag is set to true, the inputs, specifically
    // key and value, have to be pre-processed by repeating their elements
    // along the head dim so that they match the number of query heads.
    // The reference code is available here:
    // https://github.com/pytorch/pytorch/pull/132689/files#diff-e726853e9795dfb6c74ab1e10945f5d5f24540eb7bc633e5c76f69bc258f24d6R612
    if (isGQAEnabled) {
      if (failed(preProcessGroupQueryAttentionInput(
              op, rewriter, getTypeConverter(), query, key, value)))
        return failure();
    }

    llvm::SmallVector<ReassociationIndices, 3> reassociation(3);
    for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i)
      reassociation.front().push_back(i);
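The pre-processing added above has a compact eager-PyTorch equivalent, shown here as a hedged sketch: shapes are borrowed from the new e2e test below, and it assumes a PyTorch build where scaled_dot_product_attention accepts enable_gqa (2.5+):

import torch
import torch.nn.functional as F

# 32 query heads sharing 8 key/value heads, as in the e2e test below.
q = torch.randn(4, 32, 3, 8)
k = torch.randn(4, 8, 3, 8)
v = torch.randn(4, 8, 3, 8)

# What preProcessGroupQueryAttentionInput does: repeat key and value along
# the head dim (rank - 3) until all three operands share one head count.
k_rep = torch.repeat_interleave(k, q.shape[-3] // k.shape[-3], dim=-3)
v_rep = torch.repeat_interleave(v, q.shape[-3] // v.shape[-3], dim=-3)

# Plain SDPA over the repeated tensors matches SDPA with enable_gqa=True.
out_ref = F.scaled_dot_product_attention(q, k_rep, v_rep)
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
assert torch.allclose(out_gqa, out_ref, atol=1e-5)

Expanding key and value up front lets the existing (equal-head-count) attention lowering run unchanged, at the cost of materializing the repeated tensors.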
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -940,6 +940,7 @@
"BernoulliFloatModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3252,6 +3253,7 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"ScaledDotProductAttentionGQAModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3764,6 +3766,7 @@
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
}

ONNX_TOSA_CRASHING_SET = {
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5742,6 +5742,33 @@ def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
    module.forward(query, key, value, mask)


class ScaledDotProductAttentionGQAModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 32, 3, 8], torch.float32, True),
            ([4, 8, 3, 8], torch.float32, True),
            ([4, 8, 3, 8], torch.float32, True),
        ]
    )
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(
            query, key, value, enable_gqa=True
        )


@register_test_case(module_factory=lambda: ScaledDotProductAttentionGQAModule())
def ScaledDotProductAttentionGQAModule_basic(module, tu: TestUtils):
    query = torch.randn(4, 32, 3, 8, dtype=torch.float32)
    key = torch.randn(4, 8, 3, 8, dtype=torch.float32)
    value = torch.randn(4, 8, 3, 8, dtype=torch.float32)
    module.forward(query, key, value)


# ==============================================================================


