Skip to content

Commit fe63c2f

Browse files
[LVL][CSA] Legalize CSA vectorization
1 parent 87edc73 commit fe63c2f

File tree

7 files changed

+72
-4
lines changed

7 files changed

+72
-4
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+9
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,10 @@ class TargetTransformInfo {
17671767
: EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
17681768
};
17691769

1770+
/// \returns true if the loop vectorizer should vectorize conditional
1771+
/// scalar assignments for the target.
1772+
bool enableCSAVectorization() const;
1773+
17701774
/// \returns How the target needs this vector-predicated operation to be
17711775
/// transformed.
17721776
VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
@@ -2175,6 +2179,7 @@ class TargetTransformInfo::Concept {
21752179
virtual bool supportsScalableVectors() const = 0;
21762180
virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
21772181
Align Alignment) const = 0;
2182+
virtual bool enableCSAVectorization() const = 0;
21782183
virtual VPLegalization
21792184
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
21802185
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -2940,6 +2945,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
29402945
return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
29412946
}
29422947

2948+
/// Forwards the CSA-vectorization query to the wrapped implementation
/// object \c Impl and returns its answer unchanged.
bool enableCSAVectorization() const override {
2949+
return Impl.enableCSAVectorization();
2950+
}
2951+
29432952
VPLegalization
29442953
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
29452954
return Impl.getVPLegalizationStrategy(PI);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+2
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,8 @@ class TargetTransformInfoImplBase {
956956
return false;
957957
}
958958

959+
/// Default implementation: reports that CSA vectorization is not enabled.
bool enableCSAVectorization() const {
  return false;
}
960+
959961
TargetTransformInfo::VPLegalization
960962
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
961963
return TargetTransformInfo::VPLegalization(

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

+18
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONLEGALITY_H
2828

2929
#include "llvm/ADT/MapVector.h"
30+
#include "llvm/Analysis/CSADescriptors.h"
3031
#include "llvm/Analysis/LoopAccessAnalysis.h"
3132
#include "llvm/Support/TypeSize.h"
3233
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -257,6 +258,10 @@ class LoopVectorizationLegality {
257258
/// induction descriptor.
258259
using InductionList = MapVector<PHINode *, InductionDescriptor>;
259260

261+
/// CSAList contains the CSA descriptors for all the CSAs that were found
262+
/// in the loop, rooted by their phis.
263+
using CSAList = MapVector<PHINode *, CSADescriptor>;
264+
260265
/// RecurrenceSet contains the phi nodes that are recurrences other than
261266
/// inductions and reductions.
262267
using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
@@ -309,6 +314,12 @@ class LoopVectorizationLegality {
309314
/// Returns True if V is a Phi node of an induction variable in this loop.
310315
bool isInductionPhi(const Value *V) const;
311316

317+
/// Returns the CSAs found in the loop.
318+
/// The list is handed out by const reference, so callers read the
/// container stored in this object rather than a copy of it.
const CSAList &getCSAs() const { return CSAs; }
319+
320+
/// Returns true if Phi is the root of a CSA in the loop.
321+
/// A phi is a CSA root exactly when it appears as a key in the CSA list.
bool isCSAPhi(PHINode *Phi) const {
  return CSAs.find(Phi) != CSAs.end();
}
322+
312323
/// Returns a pointer to the induction descriptor, if \p Phi is an integer or
313324
/// floating point induction.
314325
const InductionDescriptor *getIntOrFpInductionDescriptor(PHINode *Phi) const;
@@ -463,6 +474,10 @@ class LoopVectorizationLegality {
463474
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
464475
SmallPtrSetImpl<Value *> &AllowedExit);
465476

477+
/// Updates the vectorization state by adding \p Phi to the CSA list.
478+
void addCSAPhi(PHINode *Phi, const CSADescriptor &CSADesc,
479+
SmallPtrSetImpl<Value *> &AllowedExit);
480+
466481
/// The loop that we evaluate.
467482
Loop *TheLoop;
468483

@@ -507,6 +522,9 @@ class LoopVectorizationLegality {
507522
/// variables can be pointers.
508523
InductionList Inductions;
509524

525+
/// Holds the conditional scalar assignments (CSAs) found in the loop.
526+
CSAList CSAs;
527+
510528
/// Holds all the casts that participate in the update chain of the induction
511529
/// variables, and that have been proven to be redundant (possibly under a
512530
/// runtime guard). These casts can be ignored when creating the vectorized

llvm/lib/Analysis/TargetTransformInfo.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,10 @@ bool TargetTransformInfo::preferEpilogueVectorization() const {
13041304
return TTIImpl->preferEpilogueVectorization();
13051305
}
13061306

1307+
// Public entry point for the CSA-vectorization query: delegates to the
// implementation object held in TTIImpl and returns its result.
bool TargetTransformInfo::enableCSAVectorization() const {
1308+
return TTIImpl->enableCSAVectorization();
1309+
}
1310+
13071311
TargetTransformInfo::VPLegalization
13081312
TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
13091313
return TTIImpl->getVPLegalizationStrategy(VPI);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,11 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
19851985
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
19861986
}
19871987

1988+
// RISC-V answer to the CSA-vectorization query. Returns true only when both
// conditions hold: the subtarget reports vector instructions
// (hasVInstructions()) and its processor family is SiFive7.
bool RISCVTTIImpl::enableCSAVectorization() const {
1989+
return ST->hasVInstructions() &&
1990+
ST->getProcFamily() == RISCVSubtarget::SiFive7;
1991+
}
1992+
19881993
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
19891994
auto *VTy = dyn_cast<VectorType>(DataTy);
19901995
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

+4
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
287287
return TLI->isVScaleKnownToBeAPowerOfTwo();
288288
}
289289

290+
/// \returns true if the loop vectorizer should vectorize conditional
291+
/// scalar assignments for the target.
292+
bool enableCSAVectorization() const;
293+
290294
/// \returns How the target needs this vector-predicated operation to be
291295
/// transformed.
292296
TargetTransformInfo::VPLegalization

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

+30-4
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ static cl::opt<LoopVectorizeHints::ScalableForceKind>
7979
"Scalable vectorization is available and favored when the "
8080
"cost is inconclusive.")));
8181

82+
// Hidden boolean command-line option, default false. Where it is consulted
// in canVectorizeInstrs(), an explicit occurrence of the flag decides the
// outcome (true enables, false disables); when the flag does not occur on the
// command line, TTI->enableCSAVectorization() decides instead.
static cl::opt<bool>
83+
EnableCSA("enable-csa-vectorization", cl::init(false), cl::Hidden,
84+
cl::desc("Control whether CSA loop vectorization is enabled"));
85+
8286
/// Maximum vectorization interleave count.
8387
static const unsigned MaxInterleaveFactor = 16;
8488

@@ -749,6 +753,15 @@ bool LoopVectorizationLegality::setupOuterLoopInductions() {
749753
return llvm::all_of(Header->phis(), IsSupportedPhi);
750754
}
751755

756+
// Records \p Phi as the root of a conditional scalar assignment (CSA).
//
// \param Phi          The phi node that roots the CSA.
// \param CSADesc      Descriptor for the CSA; asserted to be valid.
// \param AllowedExit  Set of allowed exit values; \p Phi is added to it.
//
// The descriptor is stored in the CSAs map keyed by \p Phi.
void LoopVectorizationLegality::addCSAPhi(
757+
PHINode *Phi, const CSADescriptor &CSADesc,
758+
SmallPtrSetImpl<Value *> &AllowedExit) {
759+
// Callers are expected to pass only descriptors that describe a real CSA.
assert(CSADesc.isValid() && "Expected Valid CSADescriptor");
760+
LLVM_DEBUG(dbgs() << "LV: found legal CSA opportunity" << *Phi << "\n");
761+
// Mark the phi itself as an allowed exit value.
AllowedExit.insert(Phi);
762+
// Remember the descriptor under its root phi.
CSAs.insert({Phi, CSADesc});
763+
}
764+
752765
/// Checks if a function is scalarizable according to the TLI, in
753766
/// the sense that it should be vectorized and then expanded in
754767
/// multiple scalar calls. This is represented in the
@@ -866,14 +879,23 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
866879
continue;
867880
}
868881

869-
// As a last resort, coerce the PHI to a AddRec expression
870-
// and re-try classifying it a an induction PHI.
882+
// Try to coerce the PHI to an AddRec expression and re-try classifying
883+
// it as an induction PHI.
871884
if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) &&
872885
!IsDisallowedStridedPointerInduction(ID)) {
873886
addInductionPhi(Phi, ID, AllowedExit);
874887
continue;
875888
}
876889

890+
// Check if the PHI can be classified as a CSA PHI.
891+
if (EnableCSA || (TTI->enableCSAVectorization() &&
892+
EnableCSA.getNumOccurrences() == 0)) {
893+
if (auto CSADesc = CSADescriptor::isCSAPhi(Phi, TheLoop)) {
894+
addCSAPhi(Phi, CSADesc, AllowedExit);
895+
continue;
896+
}
897+
}
898+
877899
reportVectorizationFailure("Found an unidentified PHI",
878900
"value that could not be identified as "
879901
"reduction is used outside the loop",
@@ -1555,11 +1577,15 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
15551577
for (const auto &Reduction : getReductionVars())
15561578
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
15571579

1580+
SmallPtrSet<const Value *, 8> CSALiveOuts;
1581+
for (const auto &CSA: getCSAs())
1582+
CSALiveOuts.insert(CSA.second.getAssignment());
1583+
15581584
// TODO: handle non-reduction outside users when tail is folded by masking.
15591585
for (auto *AE : AllowedExit) {
15601586
// Check that all users of allowed exit values are inside the loop or
1561-
// are the live-out of a reduction.
1562-
if (ReductionLiveOuts.count(AE))
1587+
// are the live-out of a reduction or a CSA.
1588+
if (ReductionLiveOuts.count(AE) || CSALiveOuts.count(AE))
15631589
continue;
15641590
for (User *U : AE->users()) {
15651591
Instruction *UI = cast<Instruction>(U);

0 commit comments

Comments
 (0)