@@ -331,6 +331,12 @@ struct MHAToFlashAttention
331
331
ValueRange{curSumSlice, rescaledPrevSumSlice},
332
332
ValueRange{reducedShapeOut})
333
333
.getResult (0 );
334
+ Value newSumSliceRecip =
335
+ rewriter
336
+ .create <linalg::ReciprocalOp>(loc, reducedShapeOut.getType (),
337
+ ValueRange{newSumSlice},
338
+ ValueRange{reducedShapeOut})
339
+ .getResult (0 );
334
340
SmallVector<int64_t > VShape{cfg.RowBlockSize , headDim};
335
341
Value VShapeOut = rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
336
342
Value matmulVOutFilled =
@@ -341,38 +347,40 @@ struct MHAToFlashAttention
341
347
ValueRange{PSlice, collapsedVSlice},
342
348
ValueRange{matmulVOutFilled})
343
349
.getResult (0 );
344
- Value expMaxDiffBroadcasted =
350
+ Value newSumSliceRecipBroadcasted =
345
351
rewriter
346
- .create <linalg::BroadcastOp>(loc, expMaxDiff , VShapeOut,
352
+ .create <linalg::BroadcastOp>(loc, newSumSliceRecip , VShapeOut,
347
353
SmallVector<int64_t >{1 })
348
354
.getResults ()[0 ];
349
- Value expMaxDiffBroadcastedEps =
355
+ Value rescaledPrevSumSliceBroadcasted =
350
356
rewriter
351
- .create <linalg::GenericOp>(
352
- loc, VShapeOut.getType (), ValueRange{expMaxDiffBroadcasted},
353
- ValueRange{VShapeOut}, indexingMaps,
354
- SmallVector<utils::IteratorType>(2 ,
355
- utils::IteratorType::parallel),
356
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
357
- ValueRange args) {
358
- Value eps = nestedBuilder.create <arith::ConstantOp>(
359
- loc, nestedBuilder.getFloatAttr (dtype, 1e-9 ));
360
- Value added =
361
- nestedBuilder.create <arith::AddFOp>(loc, args[0 ], eps);
362
- nestedBuilder.create <linalg::YieldOp>(nestedLoc, added);
363
- })
357
+ .create <linalg::BroadcastOp>(loc, rescaledPrevSumSlice, VShapeOut,
358
+ SmallVector<int64_t >{1 })
359
+ .getResults ()[0 ];
360
+ Value rescaledMatmulV =
361
+ rewriter
362
+ .create <linalg::MulOp>(
363
+ loc, matmulVOutFilled.getType (),
364
+ ValueRange{matmulV, newSumSliceRecipBroadcasted},
365
+ ValueRange{matmulVOutFilled})
366
+ .getResult (0 );
367
+ Value sumSliceQuotient =
368
+ rewriter
369
+ .create <linalg::MulOp>(loc, matmulVOutFilled.getType (),
370
+ ValueRange{rescaledPrevSumSliceBroadcasted,
371
+ newSumSliceRecipBroadcasted},
372
+ ValueRange{matmulVOutFilled})
364
373
.getResult (0 );
365
374
Value rescaledOSlice =
366
375
rewriter
367
- .create <linalg::DivOp>(
368
- loc, VShapeOut.getType (),
369
- ValueRange{prevOSlice, expMaxDiffBroadcastedEps},
370
- ValueRange{VShapeOut})
376
+ .create <linalg::MulOp>(loc, matmulVOutFilled.getType (),
377
+ ValueRange{prevOSlice, sumSliceQuotient},
378
+ ValueRange{matmulVOutFilled})
371
379
.getResult (0 );
372
380
Value newOSlice =
373
381
rewriter
374
382
.create <linalg::AddOp>(loc, VShapeOut.getType (),
375
- ValueRange{rescaledOSlice, matmulV },
383
+ ValueRange{rescaledOSlice, rescaledMatmulV },
376
384
ValueRange{VShapeOut})
377
385
.getResult (0 );
378
386
// yield all the results of the innermost loop.
@@ -381,25 +389,7 @@ struct MHAToFlashAttention
381
389
// yield rowBlockLoop results
382
390
rewriter.setInsertionPointToEnd (rowBlockLoop.getBody ());
383
391
auto innermostLoopResults = columnBlockLoop->getResults ();
384
- Value OSliceFinal = innermostLoopResults[0 ],
385
- sumSliceFinal = innermostLoopResults[2 ];
386
- Value sliceShapeOut =
387
- rewriter.create <tensor::EmptyOp>(loc, reducedShape, dtype);
388
- Value broadcastedSliceShapeOut =
389
- rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
390
- Value sumSliceFinalBroadcasted =
391
- rewriter
392
- .create <linalg::BroadcastOp>(loc, sumSliceFinal,
393
- broadcastedSliceShapeOut,
394
- SmallVector<int64_t >{1 })
395
- .getResults ()[0 ];
396
- Value rescaledOSliceFinal =
397
- rewriter
398
- .create <linalg::DivOp>(
399
- loc, broadcastedSliceShapeOut.getType (),
400
- ValueRange{OSliceFinal, sumSliceFinalBroadcasted},
401
- ValueRange{broadcastedSliceShapeOut})
402
- .getResult (0 );
392
+ Value OSliceFinal = innermostLoopResults[0 ];
403
393
SmallVector<OpFoldResult> outputOffsets;
404
394
outputOffsets.push_back (getAsOpFoldResult (ivs[0 ]));
405
395
outputOffsets.push_back (getAsOpFoldResult (ivs[1 ]));
@@ -409,8 +399,8 @@ struct MHAToFlashAttention
409
399
outputSizes[2 ] = rewriter.getIndexAttr (cfg.RowBlockSize );
410
400
outputSizes[3 ] = rewriter.getIndexAttr (headDim);
411
401
Value insertedRescaledOSlice = rewriter.create <tensor::InsertSliceOp>(
412
- loc, rescaledOSliceFinal , rowBlockLoop.getRegionIterArgs ()[0 ],
413
- outputOffsets, outputSizes, strides);
402
+ loc, OSliceFinal , rowBlockLoop.getRegionIterArgs ()[0 ], outputOffsets ,
403
+ outputSizes, strides);
414
404
rewriter.create <scf::YieldOp>(loc, ValueRange{insertedRescaledOSlice});
415
405
// Add the scf.yield operations for all the outer loops.
416
406
for (auto [outerLoop, innerLoop] :
0 commit comments