Skip to content

Commit 9ed0b84

Browse files
committed
align computation with triton
1 parent edc6cc0 commit 9ed0b84

File tree

1 file changed

+32
-42
lines changed

1 file changed

+32
-42
lines changed

lib/gc/Transforms/FlashAttentionConversion.cpp

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,12 @@ struct MHAToFlashAttention
331331
ValueRange{curSumSlice, rescaledPrevSumSlice},
332332
ValueRange{reducedShapeOut})
333333
.getResult(0);
334+
Value newSumSliceRecip =
335+
rewriter
336+
.create<linalg::ReciprocalOp>(loc, reducedShapeOut.getType(),
337+
ValueRange{newSumSlice},
338+
ValueRange{reducedShapeOut})
339+
.getResult(0);
334340
SmallVector<int64_t> VShape{cfg.RowBlockSize, headDim};
335341
Value VShapeOut = rewriter.create<tensor::EmptyOp>(loc, VShape, dtype);
336342
Value matmulVOutFilled =
@@ -341,38 +347,40 @@ struct MHAToFlashAttention
341347
ValueRange{PSlice, collapsedVSlice},
342348
ValueRange{matmulVOutFilled})
343349
.getResult(0);
344-
Value expMaxDiffBroadcasted =
350+
Value newSumSliceRecipBroadcasted =
345351
rewriter
346-
.create<linalg::BroadcastOp>(loc, expMaxDiff, VShapeOut,
352+
.create<linalg::BroadcastOp>(loc, newSumSliceRecip, VShapeOut,
347353
SmallVector<int64_t>{1})
348354
.getResults()[0];
349-
Value expMaxDiffBroadcastedEps =
355+
Value rescaledPrevSumSliceBroadcasted =
350356
rewriter
351-
.create<linalg::GenericOp>(
352-
loc, VShapeOut.getType(), ValueRange{expMaxDiffBroadcasted},
353-
ValueRange{VShapeOut}, indexingMaps,
354-
SmallVector<utils::IteratorType>(2,
355-
utils::IteratorType::parallel),
356-
[&](OpBuilder &nestedBuilder, Location nestedLoc,
357-
ValueRange args) {
358-
Value eps = nestedBuilder.create<arith::ConstantOp>(
359-
loc, nestedBuilder.getFloatAttr(dtype, 1e-9));
360-
Value added =
361-
nestedBuilder.create<arith::AddFOp>(loc, args[0], eps);
362-
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
363-
})
357+
.create<linalg::BroadcastOp>(loc, rescaledPrevSumSlice, VShapeOut,
358+
SmallVector<int64_t>{1})
359+
.getResults()[0];
360+
Value rescaledMatmulV =
361+
rewriter
362+
.create<linalg::MulOp>(
363+
loc, matmulVOutFilled.getType(),
364+
ValueRange{matmulV, newSumSliceRecipBroadcasted},
365+
ValueRange{matmulVOutFilled})
366+
.getResult(0);
367+
Value sumSliceQuotient =
368+
rewriter
369+
.create<linalg::MulOp>(loc, matmulVOutFilled.getType(),
370+
ValueRange{rescaledPrevSumSliceBroadcasted,
371+
newSumSliceRecipBroadcasted},
372+
ValueRange{matmulVOutFilled})
364373
.getResult(0);
365374
Value rescaledOSlice =
366375
rewriter
367-
.create<linalg::DivOp>(
368-
loc, VShapeOut.getType(),
369-
ValueRange{prevOSlice, expMaxDiffBroadcastedEps},
370-
ValueRange{VShapeOut})
376+
.create<linalg::MulOp>(loc, matmulVOutFilled.getType(),
377+
ValueRange{prevOSlice, sumSliceQuotient},
378+
ValueRange{matmulVOutFilled})
371379
.getResult(0);
372380
Value newOSlice =
373381
rewriter
374382
.create<linalg::AddOp>(loc, VShapeOut.getType(),
375-
ValueRange{rescaledOSlice, matmulV},
383+
ValueRange{rescaledOSlice, rescaledMatmulV},
376384
ValueRange{VShapeOut})
377385
.getResult(0);
378386
// yield all the results of the innermost loop.
@@ -381,25 +389,7 @@ struct MHAToFlashAttention
381389
// yield rowBlockLoop results
382390
rewriter.setInsertionPointToEnd(rowBlockLoop.getBody());
383391
auto innermostLoopResults = columnBlockLoop->getResults();
384-
Value OSliceFinal = innermostLoopResults[0],
385-
sumSliceFinal = innermostLoopResults[2];
386-
Value sliceShapeOut =
387-
rewriter.create<tensor::EmptyOp>(loc, reducedShape, dtype);
388-
Value broadcastedSliceShapeOut =
389-
rewriter.create<tensor::EmptyOp>(loc, VShape, dtype);
390-
Value sumSliceFinalBroadcasted =
391-
rewriter
392-
.create<linalg::BroadcastOp>(loc, sumSliceFinal,
393-
broadcastedSliceShapeOut,
394-
SmallVector<int64_t>{1})
395-
.getResults()[0];
396-
Value rescaledOSliceFinal =
397-
rewriter
398-
.create<linalg::DivOp>(
399-
loc, broadcastedSliceShapeOut.getType(),
400-
ValueRange{OSliceFinal, sumSliceFinalBroadcasted},
401-
ValueRange{broadcastedSliceShapeOut})
402-
.getResult(0);
392+
Value OSliceFinal = innermostLoopResults[0];
403393
SmallVector<OpFoldResult> outputOffsets;
404394
outputOffsets.push_back(getAsOpFoldResult(ivs[0]));
405395
outputOffsets.push_back(getAsOpFoldResult(ivs[1]));
@@ -409,8 +399,8 @@ struct MHAToFlashAttention
409399
outputSizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize);
410400
outputSizes[3] = rewriter.getIndexAttr(headDim);
411401
Value insertedRescaledOSlice = rewriter.create<tensor::InsertSliceOp>(
412-
loc, rescaledOSliceFinal, rowBlockLoop.getRegionIterArgs()[0],
413-
outputOffsets, outputSizes, strides);
402+
loc, OSliceFinal, rowBlockLoop.getRegionIterArgs()[0], outputOffsets,
403+
outputSizes, strides);
414404
rewriter.create<scf::YieldOp>(loc, ValueRange{insertedRescaledOSlice});
415405
// Add the scf.yield operations for all the outer loops.
416406
for (auto [outerLoop, innerLoop] :

0 commit comments

Comments
 (0)