Make VNNI more robust #1001

Merged: 26 commits, Jan 17, 2025
Changes from 18 commits
20 changes: 15 additions & 5 deletions include/TPP/Transforms/Utils/VNNIUtils.h
@@ -20,6 +20,7 @@ class OpOperand;
class AffineDimExpr;
class AffineMap;
class VectorType;
class Operation;

namespace linalg {
class LinalgOp;
@@ -35,16 +36,25 @@ enum class VnniOperandRank {
BRGEMM_OUTS = 3
};

// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8.
std::optional<int64_t> getVnniBlockingFactor(Type type);
// Return the VNNI blocking factor.
// Optionally, an operation can be provided to give access to DLTI.
std::optional<int64_t> getVnniBlockingFactor(Type type,
Operation *op = nullptr);

// Return true if the memref is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref);
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref,
std::optional<int64_t> blockingFactor = std::nullopt);

// Return true if the vector is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector);
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector,
std::optional<int64_t> blockingFactor = std::nullopt);

bool isInVnniLayout(int64_t expectedRank, VectorType vector);
// Return true if the vector is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(int64_t expectedRank, VectorType vector,
std::optional<int64_t> blockingFactor = std::nullopt);

// Return true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
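Taken together, the new optional parameters let a caller thread an operation through for DLTI lookup and pin the layout checks to the derived factor. A minimal caller sketch against the header above (the helper name is hypothetical):

#include "TPP/Transforms/Utils/VNNIUtils.h"

// Hypothetical helper: validate a GEMM operand against the factor derived
// for `op` (DLTI hint if present, libxsmm CPUID default otherwise).
static bool isVnniGemmOperand(mlir::MemRefType memref, mlir::Operation *op) {
  std::optional<int64_t> factor =
      mlir::vnni::utils::getVnniBlockingFactor(memref, op);
  if (!factor)
    return false; // non-BF16 element type: no VNNI packing applies
  return mlir::vnni::utils::isInVnniLayout(
      mlir::vnni::utils::VnniOperandRank::GEMM, memref, factor);
}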
@@ -1068,8 +1068,9 @@ struct ConvertVnniPacking : public OpRewritePattern<linalg::TransposeOp> {
if (failed(stridesOnOutput) || stridesOnOutput->back() != 1)
return failure();
// Adjust ldo based on the VNNI factor.
unaryInfo.ldo = stridesOnOutput->front() /
*vnni::utils::getVnniBlockingFactor(out.getType());
unaryInfo.ldo =
stridesOnOutput->front() /
*vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp);
auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
rewriter.getContext(), xsmm::UnaryFlags::NONE));
xsmm::UnaryKindAttr kind =
@@ -1112,7 +1113,7 @@ struct ConvertGenericToVnniMatmulLikeOp
// Take the whole reduction dim size. Account for the VNNI factor (ensured
// by the earlier check) that splits the K dim in the shape.
std::optional<int64_t> vnniFactor =
vnni::utils::getVnniBlockingFactor(bufferB.getType());
vnni::utils::getVnniBlockingFactor(bufferB.getType(), genericOp);
if (!vnniFactor)
return rewriter.notifyMatchFailure(genericOp,
"failed to determine VNNI factor");
@@ -99,7 +99,7 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE,
outType)) {
// Adjust ldo based on vnni factor
auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType);
auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType, transposeOp);
unaryInfo.ldo = unaryInfo.ldo / vnniFactor;
} else {
std::swap(unaryInfo.m, unaryInfo.n);
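The division folds the packing into the leading dimension: rows are stored in groups of the VNNI factor, so the outer stride shrinks accordingly. A sanity sketch with hypothetical values:

// Hypothetical values illustrating the ldo adjustment above.
int64_t outerStride = 64; // outermost stride of the packed output
int64_t vnniFactor = 2;   // e.g. from getVnniBlockingFactor(outType, transposeOp)
int64_t ldo = outerStride / vnniFactor; // libxsmm sees ldo == 32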
11 changes: 7 additions & 4 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
@@ -455,10 +455,13 @@ LogicalResult GemmOp::verify() {
auto memref = dyn_cast<MemRefType>(memrefOperands[idx].getType());
assert(memref && (memref.getRank() == 2 || memref.getRank() == 3));

if (memref.getRank() == 3 &&
!vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
memref)) {
return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
if (memref.getRank() == 3) {
auto vnniFactor = vnni::utils::getVnniBlockingFactor(memref);
if (!vnniFactor || (*vnniFactor) % 2 != 0 ||
!vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
memref)) {
return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
}
}
}
return success();
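In other words, a rank-3 operand is only accepted when a blocking factor can be derived, that factor is even, and the memref matches the VNNI GEMM layout. A hypothetical standalone mirror of that branch:

#include "TPP/Transforms/Utils/VNNIUtils.h"

// Hypothetical mirror of the tightened rank-3 branch in GemmOp::verify.
static bool isAcceptableGemmOperand(mlir::MemRefType memref) {
  if (memref.getRank() != 3)
    return true; // rank-2 operands use the plain (non-VNNI) layout
  auto factor = mlir::vnni::utils::getVnniBlockingFactor(memref);
  return factor && (*factor % 2 == 0) &&
         mlir::vnni::utils::isInVnniLayout(
             mlir::vnni::utils::VnniOperandRank::GEMM, memref);
}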
9 changes: 6 additions & 3 deletions lib/TPP/Dialect/Xsmm/XsmmVerify.cpp
@@ -71,21 +71,24 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) {
: vnni::utils::VnniOperandRank::GEMM;

// VNNI flags must be consistent with the memref shapes.
auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp);
ArrayAttr flags = dispatchOp->getFlags();
for (auto flag : flags) {
int64_t gemmFlag = cast<IntegerAttr>(flag).getInt();
if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_A) &&
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA)) {
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA,
vnniFactor)) {
return gemmOp.emitOpError(
"expect VNNI layout for operand A or invalid VNNI_A flags");
}
if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_B) &&
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB)) {
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB,
vnniFactor)) {
return gemmOp.emitOpError(
"expect VNNI layout for operand B or invalid VNNI_B flags");
}
if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_C) &&
!vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC)) {
!vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) {
return gemmOp.emitOpError(
"expect VNNI layout for operand C or invalid VNNI_C flags");
}
6 changes: 3 additions & 3 deletions lib/TPP/IR/MatcherUtils.cpp
@@ -40,8 +40,8 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap,
std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,
SmallVectorImpl<Value> *operands) {
bool hasBatch = false;
auto blockingFactor =
vnni::utils::getVnniBlockingFactor(linalgOp->getOperands()[0].getType());
auto blockingFactor = vnni::utils::getVnniBlockingFactor(
linalgOp->getOperands()[0].getType(), linalgOp);
if (!blockingFactor)
return std::make_pair(false, hasBatch);

@@ -115,7 +115,7 @@ std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,

// At this point, the operation is a valid matmul contraction.
// Finally, ensure that it is in VNNI layout.
bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor);
bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp);
return std::make_pair(isVnniMatmul, hasBatch);
}

10 changes: 6 additions & 4 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -333,13 +333,13 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter,

OpOperand &operandB = matmulOp->getOpOperand(1);
auto blockingFactor =
vnni::utils::getVnniBlockingFactor(operandB.get().getType());
vnni::utils::getVnniBlockingFactor(operandB.get().getType(), matmulOp);
if (!blockingFactor) {
return rewriter.notifyMatchFailure(matmulOp,
"unsupported blocking factor for type");
}

if (vnni::utils::isInVnniLayout(matmulOp, *blockingFactor)) {
if (vnni::utils::isInVnniLayout(matmulOp)) {
return rewriter.notifyMatchFailure(matmulOp, "already packed to VNNI");
}

@@ -409,12 +409,14 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter,

Value operandB = brgemmOp.getInputs()[1];
// Blocking factor on the `k` dimension.
auto blockingFactor = vnni::utils::getVnniBlockingFactor(operandB.getType());
auto blockingFactor =
vnni::utils::getVnniBlockingFactor(operandB.getType(), brgemmOp);
if (!blockingFactor) {
return rewriter.notifyMatchFailure(brgemmOp,
"unsupported blocking factor for type");
}
SmallVector<OpFoldResult> tilesOnK = {rewriter.getI64IntegerAttr(2)};
SmallVector<OpFoldResult> tilesOnK = {
rewriter.getI64IntegerAttr(*blockingFactor)};

Location loc = brgemmOp.getLoc();
// Reshape input A.
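With the k-tile tied to the derived factor, the packed shape of the BRGEMM B operand follows B x K/f x N x f rather than assuming f = 2. A small shape sketch (hypothetical helper, assuming f evenly divides K):

#include <array>
#include <cstdint>

// Hypothetical helper: VNNI-pack a BRGEMM B shape (B x K x N -> B x K/f x N x f).
std::array<int64_t, 4> packBShape(const std::array<int64_t, 3> &bkn, int64_t f) {
  return {bkn[0], bkn[1] / f, bkn[2], f};
}
// e.g. f = 2: {2, 8, 16} -> {2, 4, 16, 2}; f = 4: {2, 8, 16} -> {2, 2, 16, 4}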
63 changes: 51 additions & 12 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Utils/VNNIUtils.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -20,22 +21,54 @@ namespace mlir {
namespace vnni {
namespace utils {

std::optional<int64_t> getVnniBlockingFactor(Type type) {
std::optional<int64_t> getVnniBlockingFactor(Type type, Operation *op) {
auto elementType = getElementTypeOrSelf(type);
if (elementType.isBF16())
if (elementType.isBF16()) {
// Check if a VNNI factor hint is associated with the IR via DLTI.
auto deriveVnniFromDLTI = [&]() -> std::optional<int64_t> {
if (!op)
return std::nullopt;
ModuleOp moduleOp = op->getParentOfType<mlir::ModuleOp>();
if (!moduleOp)
return std::nullopt;
TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
if (!sysSpec)
return std::nullopt;
auto deviceId = StringAttr::get(moduleOp->getContext(), "CPU");
auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
if (!deviceSpec)
return std::nullopt;
auto vnniId = StringAttr::get(moduleOp->getContext(), "vnni");
DataLayoutEntryInterface entry =
(*deviceSpec).getSpecForIdentifier(vnniId);
if (!entry)
return std::nullopt;
Attribute value = entry.getValue();
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
return intAttr.getInt();
return std::nullopt;
};
if (auto vnniFactor = deriveVnniFromDLTI())
return *vnniFactor;

return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16);
}
return std::nullopt;
}

// Until we have a better way to express the VNNI layout (see: #563), it is up
// to the caller to specify the expected rank in the VNNI layout as the rank
// depends on the operations we are dealing with.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref) {
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref,
std::optional<int64_t> blockingFactor) {
if (memref.getRank() != static_cast<int64_t>(expectedRank) ||
!memref.getElementType().isBF16()) {
!memref.getElementType().isBF16())
return false;
}
return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref);

if (blockingFactor && memref.getShape().back() != *blockingFactor)
return false;

return true;
}

bool isInVnniLayout(linalg::LinalgOp linalgOp,
@@ -109,15 +142,21 @@
return true;
}

bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) {
return isInVnniLayout(static_cast<int64_t>(expectedRank), vector);
bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector,
std::optional<int64_t> blockingFactor) {
return isInVnniLayout(static_cast<int64_t>(expectedRank), vector,
blockingFactor);
}

bool isInVnniLayout(int64_t expectedRank, VectorType vector) {
if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) {
bool isInVnniLayout(int64_t expectedRank, VectorType vector,
std::optional<int64_t> blockingFactor) {
if (vector.getRank() != expectedRank || !vector.getElementType().isBF16())
return false;
}
return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector);

if (blockingFactor && vector.getShape().back() != *blockingFactor)
return false;

return true;
}

} // namespace utils
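The DLTI hook gives the module the final say over the packing factor: with an anchoring op, an integer entry keyed "vnni" on the "CPU" device spec wins; without one, only libxsmm's CPUID answer is available. A sketch of the two call shapes (names hypothetical):

#include "TPP/Transforms/Utils/VNNIUtils.h"

// Demonstrates that the DLTI hint participates only when an op is passed.
void compareFactors(mlir::Operation *opInsideModule, mlir::Type bf16Type) {
  auto hinted = mlir::vnni::utils::getVnniBlockingFactor(bf16Type, opInsideModule);
  auto fallback = mlir::vnni::utils::getVnniBlockingFactor(bf16Type);
  // With a module-level DLTI "vnni" entry of, say, 4, `hinted` is 4 while
  // `fallback` still reflects the host CPU's BF16 packing factor.
}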
23 changes: 23 additions & 0 deletions test/BF16/Integration/lit.local.cfg
@@ -19,6 +19,29 @@ def has_support(feature):

return True

def is_arch(target):
# Arch detection not working on Windows
if sys.platform in ['win32']:
return False

try:
cmd = subprocess.Popen(
['uname', '-m'], stdout=subprocess.PIPE)
except OSError:
return False

out = cmd.stdout.read().decode('ascii')
cmd.wait()

return target in out


# AVX512, AVX2, and SVE targets should have BF16 support
if not has_support('avx512') and not has_support('avx2') and not has_support('sve'):
config.unsupported = True

# Enable only on x86.
# Other targets may use a different VNNI blocking scheme that is not compatible
# with the prepacked shapes in some of the tests.
if not is_arch('x86'):
config.unsupported = True
28 changes: 14 additions & 14 deletions test/BF16/Integration/mlir-gen-bf16.mlir
@@ -1,28 +1,28 @@
// MLP without softmax (can't print packed version for now)
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void

// Matmul only
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void

// Kernel - matmul
// RUN: mlir-gen --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16
// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16
// RUN: mlir-gen --kernel=args --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16
// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16

// Kernel - fc
// RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16
// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16
// RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16
// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16

// BF16/VNNI execution
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF


// GEN-MATMUL-BF16: ( 11, 11, 11, 11, 11, 11, 11, 11, 11, 11 )
// GEN-MATMUL-BF16: ( 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17 )

// GEN-FC-BF16: ( 12, 12, 12, 12, 12, 12, 12, 12, 12, 12 )
// GEN-FC-BF16: ( 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18 )

// PERF: {{[0-9]+}}{{.?}}{{[0-9e-]+}}
2 changes: 1 addition & 1 deletion test/BF16/Integration/tpp-run-splat-shape.mlir
@@ -41,7 +41,7 @@ func.func @entry(%arg0: tensor<4x8x8x8xbf16>, %output: tensor<4x8x8x8xbf16>) ->
// due to compile time packing.
// CHECK-NOT: memref.global "private" constant @__constant_{{.*}}: memref<8x8xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<4x8x8x8xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x2xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x{{4|2}}x8x{{2|4}}xbf16>
// CHECK: xsmm_brgemm_invoke
// CHECK: xsmm_binary_invoke
// CHECK: xsmm_unary_invoke
29 changes: 8 additions & 21 deletions test/BF16/Integration/vnni-xsmm-vs-loops.mlir
@@ -1,26 +1,13 @@
// RUN: tpp-run %s -print -seed 123 \
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 \
// RUN: --tiles=16,16,16 --float-type=bf16 | \
// RUN: tpp-opt --pack-vnni | \
// RUN: tpp-run -print -seed 123 \
// RUN: -e entry -entry-point-result=void > %t.xsmm

// RUN: tpp-run %s -print -seed 123 -linalg-to-loops \
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 \
// RUN: --tiles=16,16,16 --float-type=bf16 | \
// RUN: tpp-opt --pack-vnni | \
// RUN: tpp-run -print -seed 123 -linalg-to-loops \
// RUN: -e entry -entry-point-result=void > %t.loops

// RUN: fpcmp -r 0.01 %t.xsmm %t.loops

#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>

func.func @entry(%arg0: tensor<2x2x7x4x2xbf16>, %arg1: tensor<2x2x4x5x2xbf16>,
%arg2: tensor<2x2x7x5xbf16>) -> tensor<2x2x7x5xbf16> {
%1 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<2x2x7x4x2xbf16>, tensor<2x2x4x5x2xbf16>)
outs(%arg2 : tensor<2x2x7x5xbf16>) {
^bb0(%in: bf16, %in_0: bf16, %out: bf16):
%2 = arith.mulf %in, %in_0 : bf16
%3 = arith.addf %out, %2 : bf16
linalg.yield %3 : bf16
} -> tensor<2x2x7x5xbf16>
return %1 : tensor<2x2x7x5xbf16>
}