Skip to content

Commit 1e31c22

Browse files
[VPlan] Add cost model for CSA
1 parent 7600b9a commit 1e31c22

File tree

4 files changed

+236
-171
lines changed

4 files changed

+236
-171
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -7279,10 +7279,17 @@ InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,
72797279
/// not have corresponding recipes in \p Plan and are not marked to be ignored
72807280
/// in \p CostCtx. This means the VPlan contains simplification that the legacy
72817281
/// cost-model did not account for.
7282-
static bool
7283-
planContainsAdditionalSimplifications(VPlan &Plan, ElementCount VF,
7284-
VPCostContext &CostCtx, Loop *TheLoop,
7285-
LoopVectorizationCostModel &CM) {
7282+
static bool planContainsAdditionalSimplifications(
7283+
VPlan &Plan, ElementCount VF, VPCostContext &CostCtx, Loop *TheLoop,
7284+
LoopVectorizationCostModel &CM, LoopVectorizationLegality &Legal) {
7285+
7286+
// CSA cost is more complicated since there is significant overhead in the
7287+
// preheader and middle block. It also contains recipes that are not backed by
7288+
// underlying instructions in the original loop. This makes it difficult to
7289+
// model in the legacy cost model.
7290+
if (!Legal.getCSAs().empty())
7291+
return true;
7292+
72867293
// First collect all instructions for the recipes in Plan.
72877294
auto GetInstructionForCost = [](const VPRecipeBase *R) -> Instruction * {
72887295
if (auto *S = dyn_cast<VPSingleDefRecipe>(R))
@@ -7394,7 +7401,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
73947401
assert((BestFactor.Width == LegacyVF.Width ||
73957402
planContainsAdditionalSimplifications(getPlanFor(BestFactor.Width),
73967403
BestFactor.Width, CostCtx,
7397-
OrigLoop, CM)) &&
7404+
OrigLoop, CM, *Legal)) &&
73987405
" VPlan cost model and legacy cost model disagreed");
73997406
assert((BestFactor.Width.isScalar() || BestFactor.ScalarCost > 0) &&
74007407
"when vectorizing, the scalar cost must be computed.");

llvm/lib/Transforms/Vectorize/VPlan.h

+10-1
Original file line numberDiff line numberDiff line change
@@ -2498,6 +2498,9 @@ class VPCSAHeaderPHIRecipe final : public VPHeaderPHIRecipe {
24982498

24992499
void execute(VPTransformState &State) override;
25002500

2501+
InstructionCost computeCost(ElementCount VF,
2502+
VPCostContext &Ctx) const override;
2503+
25012504
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25022505
/// Print the recipe.
25032506
void print(raw_ostream &O, const Twine &Indent,
@@ -2529,6 +2532,9 @@ class VPCSADataUpdateRecipe final : public VPSingleDefRecipe {
25292532

25302533
void execute(VPTransformState &State) override;
25312534

2535+
InstructionCost computeCost(ElementCount VF,
2536+
VPCostContext &Ctx) const override;
2537+
25322538
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25332539
/// Print the recipe.
25342540
void print(raw_ostream &O, const Twine &Indent,
@@ -2575,6 +2581,9 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
25752581

25762582
void execute(VPTransformState &State) override;
25772583

2584+
InstructionCost computeCost(ElementCount VF,
2585+
VPCostContext &Ctx) const override;
2586+
25782587
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25792588
/// Print the recipe.
25802589
void print(raw_ostream &O, const Twine &Indent,
@@ -2585,7 +2594,7 @@ class VPCSAExtractScalarRecipe final : public VPSingleDefRecipe {
25852594
VPValue *getVPMaskSel() const { return getOperand(1); }
25862595
VPValue *getVPDataSel() const { return getOperand(2); }
25872596
VPValue *getVPCSAVLSel() const { return getOperand(3); }
2588-
bool usesEVL() { return getNumOperands() == 4; }
2597+
bool usesEVL() const { return getNumOperands() == 4; }
25892598
};
25902599

25912600
/// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

+100
Original file line numberDiff line numberDiff line change
@@ -2148,6 +2148,24 @@ void VPCSAHeaderPHIRecipe::execute(VPTransformState &State) {
21482148
State.set(this, DataPhi, Part);
21492149
}
21502150

2151+
InstructionCost VPCSAHeaderPHIRecipe::computeCost(ElementCount VF,
2152+
VPCostContext &Ctx) const {
2153+
if (VF.isScalar())
2154+
return 0;
2155+
2156+
InstructionCost C = 0;
2157+
auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
2158+
const TargetTransformInfo &TTI = Ctx.TTI;
2159+
2160+
// FIXME: These costs should be moved into VPInstruction::computeCost. We put
2161+
// them here for now since there is no VPInstruction::computeCost support.
2162+
// CSAInitMask
2163+
C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VTy);
2164+
// CSAInitData
2165+
C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VTy);
2166+
return C;
2167+
}
2168+
21512169
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
21522170
void VPCSADataUpdateRecipe::print(raw_ostream &O, const Twine &Indent,
21532171
VPSlotTracker &SlotTracker) const {
@@ -2176,6 +2194,34 @@ void VPCSADataUpdateRecipe::execute(VPTransformState &State) {
21762194
}
21772195
}
21782196

2197+
InstructionCost VPCSADataUpdateRecipe::computeCost(ElementCount VF,
2198+
VPCostContext &Ctx) const {
2199+
if (VF.isScalar())
2200+
return 0;
2201+
2202+
InstructionCost C = 0;
2203+
auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
2204+
auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
2205+
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2206+
const TargetTransformInfo &TTI = Ctx.TTI;
2207+
2208+
// Data Update
2209+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy, CostKind);
2210+
2211+
// FIXME: These costs should be moved into VPInstruction::computeCost. We put
2212+
// them here for now since they are related to updating the data and there is
2213+
// no VPInstruction::computeCost support at the moment. CSAInitMask AnyActive
2214+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy, CostKind);
2215+
// vp.reduce.or
2216+
C += TTI.getArithmeticReductionCost(Instruction::Or, VTy, std::nullopt,
2217+
CostKind);
2218+
// VPVLSel
2219+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy, CostKind);
2220+
// MaskUpdate
2221+
C += TTI.getArithmeticInstrCost(Instruction::Select, MaskTy, CostKind);
2222+
return C;
2223+
}
2224+
21792225
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
21802226
void VPCSAExtractScalarRecipe::print(raw_ostream &O, const Twine &Indent,
21812227
VPSlotTracker &SlotTracker) const {
@@ -2236,6 +2282,60 @@ void VPCSAExtractScalarRecipe::execute(VPTransformState &State) {
22362282
State.set(this, ChooseFromVecOrInit, 0, /*IsScalar=*/true);
22372283
}
22382284

2285+
InstructionCost
2286+
VPCSAExtractScalarRecipe::computeCost(ElementCount VF,
2287+
VPCostContext &Ctx) const {
2288+
if (VF.isScalar())
2289+
return 0;
2290+
2291+
InstructionCost C = 0;
2292+
auto *VTy = VectorType::get(getUnderlyingValue()->getType(), VF);
2293+
auto *Int32VTy =
2294+
VectorType::get(IntegerType::getInt32Ty(VTy->getContext()), VF);
2295+
auto *MaskTy = VectorType::get(IntegerType::getInt1Ty(VTy->getContext()), VF);
2296+
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2297+
const TargetTransformInfo &TTI = Ctx.TTI;
2298+
2299+
// StepVector
2300+
ArrayRef<Value *> Args;
2301+
IntrinsicCostAttributes CostAttrs(Intrinsic::stepvector, Int32VTy, Args);
2302+
C += TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
2303+
// NegOneSplat
2304+
C += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, Int32VTy);
2305+
// LastIdx
2306+
if (usesEVL()) {
2307+
C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
2308+
CostKind);
2309+
} else {
2310+
// ActiveLaneIdxs
2311+
C += TTI.getArithmeticInstrCost(Instruction::Select,
2312+
MaskTy->getScalarType(), CostKind);
2313+
// MaybeLastIdx
2314+
C += TTI.getMinMaxReductionCost(Intrinsic::smax, Int32VTy, FastMathFlags(),
2315+
CostKind);
2316+
// IsLaneZeroActive
2317+
C += TTI.getArithmeticInstrCost(Instruction::ExtractElement, MaskTy,
2318+
CostKind);
2319+
// MaybeLastIdxEQZero
2320+
C += TTI.getArithmeticInstrCost(Instruction::ICmp, MaskTy->getScalarType(),
2321+
CostKind);
2322+
// And
2323+
C += TTI.getArithmeticInstrCost(Instruction::And, MaskTy->getScalarType(),
2324+
CostKind);
2325+
// LastIdx
2326+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy->getScalarType(),
2327+
CostKind);
2328+
}
2329+
// ExtractFromVec
2330+
C += TTI.getArithmeticInstrCost(Instruction::ExtractElement, VTy, CostKind);
2331+
// LastIdxGeZero
2332+
C += TTI.getArithmeticInstrCost(Instruction::ICmp, Int32VTy, CostKind);
2333+
// ChooseFromVecOrInit
2334+
C += TTI.getArithmeticInstrCost(Instruction::Select, VTy->getScalarType(),
2335+
CostKind);
2336+
return C;
2337+
}
2338+
22392339
void VPBranchOnMaskRecipe::execute(VPTransformState &State) {
22402340
assert(State.Instance && "Branch on Mask works only on single instance.");
22412341

0 commit comments

Comments
 (0)