Skip to content

Commit 0e92ae9

Browse files
committed
[ConstraintElim] Add facts about non-poison intrinsics on demand
1 parent 2ba455f commit 0e92ae9

File tree

4 files changed

+87
-61
lines changed

4 files changed

+87
-61
lines changed

llvm/lib/Transforms/Scalar/ConstraintElimination.cpp

+72-51
Original file line numberDiff line numberDiff line change
@@ -1120,28 +1120,17 @@ void State::addInfoFor(BasicBlock &BB) {
11201120
}
11211121
break;
11221122
}
1123-
// Enqueue ssub_with_overflow for simplification.
1123+
// Enqueue intrinsic for simplification.
11241124
case Intrinsic::ssub_with_overflow:
11251125
case Intrinsic::ucmp:
11261126
case Intrinsic::scmp:
1127-
WorkList.push_back(
1128-
FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I)));
1129-
break;
1130-
// Enqueue the intrinsics to add extra info.
11311127
case Intrinsic::umin:
11321128
case Intrinsic::umax:
11331129
case Intrinsic::smin:
11341130
case Intrinsic::smax:
11351131
// TODO: handle llvm.abs as well
11361132
WorkList.push_back(
11371133
FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I)));
1138-
// TODO: Check if it is possible to instead only add the min/max facts
1139-
// when simplifying uses of the min/max intrinsics.
1140-
if (!isGuaranteedNotToBePoison(&I))
1141-
break;
1142-
[[fallthrough]];
1143-
case Intrinsic::abs:
1144-
WorkList.push_back(FactOrCheck::getInstFact(DT.getNode(&BB), &I));
11451134
break;
11461135
}
11471136

@@ -1385,10 +1374,64 @@ static void generateReproducer(CmpInst *Cond, Module *M,
13851374
assert(!verifyFunction(*F, &dbgs()));
13861375
}
13871376

1377+
/// Report the ordering facts implied by the intrinsic call \p II by invoking
/// \p AddFact once per fact as AddFact(Predicate, LHS, RHS).
///
/// Handled intrinsics are umin/umax/smin/smax and abs; any other intrinsic
/// produces no facts. This helper performs no poison check of its own.
/// NOTE(review): judging by the name, callers are expected to invoke it only
/// when \p II is known not to be poison -- confirm at each call site.
static void addNonPoisonIntrinsicInstFact(
1378+
IntrinsicInst *II,
1379+
function_ref<void(CmpPredicate, Value *, Value *)> AddFact) {
1380+
Intrinsic::ID IID = II->getIntrinsicID();
1381+
switch (IID) {
1382+
case Intrinsic::umin:
1383+
case Intrinsic::umax:
1384+
case Intrinsic::smin:
1385+
case Intrinsic::smax: {
1386+
// Relate the min/max result to each of its two operands using the
// non-strict form of the predicate reported for this intrinsic ID.
ICmpInst::Predicate Pred =
1387+
ICmpInst::getNonStrictPredicate(MinMaxIntrinsic::getPredicate(IID));
1388+
AddFact(Pred, II, II->getArgOperand(0));
1389+
AddFact(Pred, II, II->getArgOperand(1));
1390+
break;
1391+
}
1392+
case Intrinsic::abs: {
1393+
// The second argument is the is_int_min_poison flag; when it is true we may
// assume abs(X) >= 0 (signed).
if (cast<ConstantInt>(II->getArgOperand(1))->isOne())
1394+
AddFact(CmpInst::ICMP_SGE, II, ConstantInt::get(II->getType(), 0));
1395+
// Independently of the flag, record abs(X) >= X (signed).
AddFact(CmpInst::ICMP_SGE, II, II->getArgOperand(0));
1396+
break;
1397+
}
1398+
default:
1399+
// No facts are derived for any other intrinsic.
break;
1400+
}
1401+
}
1402+
1403+
/// Undo the bookkeeping recorded by stack entry \p E and pop it.
///
/// Pops the last constraint from the system selected by E.IsSigned, erases
/// every value in E.ValuesToRelease from that system's value-to-index mapping,
/// pops the same number of variables, and then pops the back of
/// \p DFSInStack. If \p ReproducerCondStack is non-null, its back element is
/// popped as well; pass nullptr when no reproducer stack is being maintained.
///
/// \p E is only read before DFSInStack.pop_back() is executed.
static void
1404+
removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info,
1405+
SmallVectorImpl<StackEntry> &DFSInStack,
1406+
SmallVectorImpl<ReproducerEntry> *ReproducerCondStack) {
1407+
Info.popLastConstraint(E.IsSigned);
1408+
// Remove variables in the system that went out of scope.
1409+
auto &Mapping = Info.getValue2Index(E.IsSigned);
1410+
for (Value *V : E.ValuesToRelease)
1411+
Mapping.erase(V);
1412+
Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size());
1413+
DFSInStack.pop_back();
1414+
// The reproducer stack is optional; only pop it when the caller supplied one.
if (ReproducerCondStack)
1415+
ReproducerCondStack->pop_back();
1416+
}
1417+
13881418
static std::optional<bool> checkCondition(CmpInst::Predicate Pred, Value *A,
13891419
Value *B, Instruction *CheckInst,
13901420
ConstraintInfo &Info) {
13911421
LLVM_DEBUG(dbgs() << "Checking " << *CheckInst << "\n");
1422+
SmallVector<StackEntry, 8> DFSInStack;
1423+
auto StackRestorer = make_scope_exit([&]() {
1424+
while (!DFSInStack.empty())
1425+
removeEntryFromStack(DFSInStack.back(), Info, DFSInStack, nullptr);
1426+
});
1427+
auto AddFact = [&](CmpPredicate Pred, Value *A, Value *B) {
1428+
Info.addFact(Pred, A, B, 0, 0, DFSInStack);
1429+
};
1430+
1431+
if (auto *II = dyn_cast<IntrinsicInst>(A))
1432+
addNonPoisonIntrinsicInstFact(II, AddFact);
1433+
if (auto *II = dyn_cast<IntrinsicInst>(B))
1434+
addNonPoisonIntrinsicInstFact(II, AddFact);
13921435

13931436
auto R = Info.getConstraintForSolving(Pred, A, B);
13941437
if (R.empty() || !R.isValid(Info)){
@@ -1517,22 +1560,6 @@ static bool checkAndReplaceCmp(CmpIntrinsic *I, ConstraintInfo &Info,
15171560
return false;
15181561
}
15191562

1520-
static void
1521-
removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info,
1522-
Module *ReproducerModule,
1523-
SmallVectorImpl<ReproducerEntry> &ReproducerCondStack,
1524-
SmallVectorImpl<StackEntry> &DFSInStack) {
1525-
Info.popLastConstraint(E.IsSigned);
1526-
// Remove variables in the system that went out of scope.
1527-
auto &Mapping = Info.getValue2Index(E.IsSigned);
1528-
for (Value *V : E.ValuesToRelease)
1529-
Mapping.erase(V);
1530-
Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size());
1531-
DFSInStack.pop_back();
1532-
if (ReproducerModule)
1533-
ReproducerCondStack.pop_back();
1534-
}
1535-
15361563
/// Check if either the first condition of an AND or OR is implied by the
15371564
/// (negated in case of OR) second condition or vice versa.
15381565
static bool checkOrAndOpImpliedByOther(
@@ -1554,8 +1581,8 @@ static bool checkOrAndOpImpliedByOther(
15541581
// Remove entries again.
15551582
while (OldSize < DFSInStack.size()) {
15561583
StackEntry E = DFSInStack.back();
1557-
removeEntryFromStack(E, Info, ReproducerModule, ReproducerCondStack,
1558-
DFSInStack);
1584+
removeEntryFromStack(E, Info, DFSInStack,
1585+
ReproducerModule ? &ReproducerCondStack : nullptr);
15591586
}
15601587
});
15611588
bool IsOr = match(JoinOp, m_LogicalOr());
@@ -1571,6 +1598,14 @@ static bool checkOrAndOpImpliedByOther(
15711598
Pred = CmpInst::getInversePredicate(Pred);
15721599
// Optimistically add fact from the other compares in the AND/OR.
15731600
Info.addFact(Pred, LHS, RHS, CB.NumIn, CB.NumOut, DFSInStack);
1601+
auto AddFact = [&](CmpPredicate Pred, Value *A, Value *B) {
1602+
Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
1603+
};
1604+
1605+
if (auto *II = dyn_cast<IntrinsicInst>(LHS))
1606+
addNonPoisonIntrinsicInstFact(II, AddFact);
1607+
if (auto *II = dyn_cast<IntrinsicInst>(RHS))
1608+
addNonPoisonIntrinsicInstFact(II, AddFact);
15741609
continue;
15751610
}
15761611
if (IsOr ? match(Val, m_LogicalOr(m_Value(LHS), m_Value(RHS)))
@@ -1807,8 +1842,8 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
18071842
Info.getValue2Index(E.IsSigned));
18081843
dbgs() << "\n";
18091844
});
1810-
removeEntryFromStack(E, Info, ReproducerModule.get(), ReproducerCondStack,
1811-
DFSInStack);
1845+
removeEntryFromStack(E, Info, DFSInStack,
1846+
ReproducerModule ? &ReproducerCondStack : nullptr);
18121847
}
18131848

18141849
// For a block, check if any CmpInsts become known based on the current set
@@ -1879,25 +1914,6 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
18791914
};
18801915

18811916
CmpPredicate Pred;
1882-
if (!CB.isConditionFact()) {
1883-
Value *X;
1884-
if (match(CB.Inst, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) {
1885-
// If is_int_min_poison is true then we may assume llvm.abs >= 0.
1886-
if (cast<ConstantInt>(CB.Inst->getOperand(1))->isOne())
1887-
AddFact(CmpInst::ICMP_SGE, CB.Inst,
1888-
ConstantInt::get(CB.Inst->getType(), 0));
1889-
AddFact(CmpInst::ICMP_SGE, CB.Inst, X);
1890-
continue;
1891-
}
1892-
1893-
if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
1894-
Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
1895-
AddFact(Pred, MinMax, MinMax->getLHS());
1896-
AddFact(Pred, MinMax, MinMax->getRHS());
1897-
continue;
1898-
}
1899-
}
1900-
19011917
Value *A = nullptr, *B = nullptr;
19021918
if (CB.isConditionFact()) {
19031919
Pred = CB.Cond.Pred;
@@ -1922,6 +1938,11 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
19221938
assert(Matched && "Must have an assume intrinsic with a icmp operand");
19231939
}
19241940
AddFact(Pred, A, B);
1941+
// Now both A and B are guaranteed not to be poison.
1942+
if (auto *II = dyn_cast<IntrinsicInst>(A))
1943+
addNonPoisonIntrinsicInstFact(II, AddFact);
1944+
if (auto *II = dyn_cast<IntrinsicInst>(B))
1945+
addNonPoisonIntrinsicInstFact(II, AddFact);
19251946
}
19261947

19271948
if (ReproducerModule && !ReproducerModule->functions().empty()) {

llvm/test/Transforms/ConstraintElimination/abs.ll

+10-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ define i1 @abs_plus_one(i32 %arg) {
2828
; CHECK-SAME: i32 [[ARG:%.*]]) {
2929
; CHECK-NEXT: [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[ARG]], i1 true)
3030
; CHECK-NEXT: [[ABS_PLUS_ONE:%.*]] = add nsw i32 [[ABS]], 1
31-
; CHECK-NEXT: ret i1 true
31+
; CHECK-NEXT: [[CMP:%.*]] = icmp sge i32 [[ABS_PLUS_ONE]], [[ARG]]
32+
; CHECK-NEXT: ret i1 [[CMP]]
3233
;
3334
%abs = tail call i32 @llvm.abs.i32(i32 %arg, i1 true)
3435
%abs_plus_one = add nsw i32 %abs, 1
@@ -69,7 +70,8 @@ define i1 @abs_plus_one_unsigned_greater_or_equal_nonnegative_arg(i32 %arg) {
6970
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP_ARG_NONNEGATIVE]])
7071
; CHECK-NEXT: [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[ARG]], i1 true)
7172
; CHECK-NEXT: [[ABS_PLUS_ONE:%.*]] = add nuw i32 [[ABS]], 1
72-
; CHECK-NEXT: ret i1 true
73+
; CHECK-NEXT: [[CMP:%.*]] = icmp uge i32 [[ABS_PLUS_ONE]], [[ARG]]
74+
; CHECK-NEXT: ret i1 [[CMP]]
7375
;
7476
%cmp_arg_nonnegative = icmp sge i32 %arg, 0
7577
call void @llvm.assume(i1 %cmp_arg_nonnegative)
@@ -107,7 +109,8 @@ define i1 @abs_constant_negative_arg() {
107109
define i1 @abs_constant_positive_arg() {
108110
; CHECK-LABEL: define i1 @abs_constant_positive_arg() {
109111
; CHECK-NEXT: [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 3, i1 false)
110-
; CHECK-NEXT: ret i1 true
112+
; CHECK-NEXT: [[CMP:%.*]] = icmp sge i32 [[ABS]], 3
113+
; CHECK-NEXT: ret i1 [[CMP]]
111114
;
112115
%abs = tail call i32 @llvm.abs.i32(i32 3, i1 false)
113116
%cmp = icmp sge i32 %abs, 3
@@ -142,7 +145,8 @@ define i1 @abs_is_nonnegative_int_min_is_poison(i32 %arg) {
142145
; CHECK-LABEL: define i1 @abs_is_nonnegative_int_min_is_poison(
143146
; CHECK-SAME: i32 [[ARG:%.*]]) {
144147
; CHECK-NEXT: [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[ARG]], i1 true)
145-
; CHECK-NEXT: ret i1 true
148+
; CHECK-NEXT: [[CMP:%.*]] = icmp sge i32 [[ABS]], 0
149+
; CHECK-NEXT: ret i1 [[CMP]]
146150
;
147151
%abs = tail call i32 @llvm.abs.i32(i32 %arg, i1 true)
148152
%cmp = icmp sge i32 %abs, 0
@@ -152,7 +156,8 @@ define i1 @abs_is_nonnegative_int_min_is_poison(i32 %arg) {
152156
define i1 @abs_is_nonnegative_constant_arg() {
153157
; CHECK-LABEL: define i1 @abs_is_nonnegative_constant_arg() {
154158
; CHECK-NEXT: [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 -3, i1 true)
155-
; CHECK-NEXT: ret i1 true
159+
; CHECK-NEXT: [[CMP:%.*]] = icmp sge i32 [[ABS]], 0
160+
; CHECK-NEXT: ret i1 [[CMP]]
156161
;
157162
%abs = tail call i32 @llvm.abs.i32(i32 -3, i1 true)
158163
%cmp = icmp sge i32 %abs, 0

llvm/test/Transforms/ConstraintElimination/minmax.ll

+3-4
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ define i1 @umax_uge_ugt_with_add_nuw(i32 %x, i32 %y) {
222222
; CHECK-NEXT: [[CMP:%.*]] = icmp uge i32 [[Y]], [[SUM]]
223223
; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
224224
; CHECK: if:
225-
; CHECK-NEXT: ret i1 true
225+
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
226+
; CHECK-NEXT: ret i1 [[CMP2]]
226227
; CHECK: end:
227228
; CHECK-NEXT: ret i1 false
228229
;
@@ -306,9 +307,7 @@ define i1 @smin_branchless(i32 %x, i32 %y) {
306307
; CHECK-SAME: (i32 [[X:%.*]], i32 [[Y:%.*]]) {
307308
; CHECK-NEXT: entry:
308309
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
309-
; CHECK-NEXT: [[CMP1:%.*]] = icmp sle i32 [[MIN]], [[X]]
310-
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[MIN]], [[X]]
311-
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP1]], [[CMP2]]
310+
; CHECK-NEXT: [[RET:%.*]] = xor i1 true, false
312311
; CHECK-NEXT: ret i1 [[RET]]
313312
;
314313
entry:

llvm/test/Transforms/ConstraintElimination/umin-result-may-be-poison.ll

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ define i1 @umin_not_used(i32 %arg) {
2222
define i1 @umin_poison_is_UB_via_call(i32 %arg) {
2323
; CHECK-LABEL: define i1 @umin_poison_is_UB_via_call(
2424
; CHECK-SAME: i32 [[ARG:%.*]]) {
25+
; CHECK-NEXT: [[ICMP:%.*]] = icmp slt i32 [[ARG]], 0
2526
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i32 [[ARG]], 3
2627
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[SHL]], i32 80)
2728
; CHECK-NEXT: call void @noundef(i32 noundef [[MIN]])
2829
; CHECK-NEXT: [[CMP2:%.*]] = shl nuw nsw i32 [[ARG]], 3
29-
; CHECK-NEXT: ret i1 false
30+
; CHECK-NEXT: ret i1 [[ICMP]]
3031
;
3132
%icmp = icmp slt i32 %arg, 0
3233
%shl = shl nuw nsw i32 %arg, 3

0 commit comments

Comments
 (0)