From c5d9bad8e1e39d53f2d05e31ab1898b32d29f743 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 16 Jan 2025 19:17:24 +0100 Subject: [PATCH] Verify valid factor --- lib/TPP/Dialect/Xsmm/XsmmVerify.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp index df8071ba8..384ec7ddc 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp @@ -72,23 +72,25 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) { // VNNI flags must be consistent with the memref shapes. auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp); + ArrayAttr flags = dispatchOp->getFlags(); for (auto flag : flags) { int64_t gemmFlag = cast(flag).getInt(); if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_A) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA, - vnniFactor)) { + (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankIns, + operandA, vnniFactor))) { return gemmOp.emitOpError( "expect VNNI layout for operand A or invalid VNNI_A flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_B) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB, - vnniFactor)) { + (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankIns, + operandB, vnniFactor))) { return gemmOp.emitOpError( "expect VNNI layout for operand B or invalid VNNI_B flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_C) && - !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) { + (!vnniFactor || !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, + vnniFactor))) { return gemmOp.emitOpError( "expect VNNI layout for operand C or invalid VNNI_C flags"); }