
[InstCombine] Fold lshr -> zext -> shl patterns #147737


Open · wants to merge 3 commits into main
199 changes: 147 additions & 52 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -530,112 +530,159 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
return nullptr;
}

/// Return true if we can simplify two logical (either left or right) shifts
/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
Instruction *InnerShift,
InstCombinerImpl &IC, Instruction *CxtI) {
/// Return a bitmask of all constant outer shift amounts that can be simplified
/// by foldShiftedShift().
static APInt getEvaluableShiftedShiftMask(bool IsOuterShl,
Instruction *InnerShift,
InstCombinerImpl &IC,
Instruction *CxtI) {
assert(InnerShift->isLogicalShift() && "Unexpected instruction type");

const unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();

// We need constant scalar or constant splat shifts.
const APInt *InnerShiftConst;
if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
return false;
return APInt::getZero(TypeWidth);

// Two logical shifts in the same direction:
if (InnerShiftConst->uge(TypeWidth))
return APInt::getZero(TypeWidth);

const unsigned InnerShAmt = InnerShiftConst->getZExtValue();

// Two logical shifts in the same direction can always be simplified, so long
// as the total shift amount is legal.
// shl (shl X, C1), C2 --> shl X, C1 + C2
// lshr (lshr X, C1), C2 --> lshr X, C1 + C2
bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
if (IsInnerShl == IsOuterShl)
return true;
return APInt::getLowBitsSet(TypeWidth, TypeWidth - InnerShAmt);

APInt ShMask = APInt::getZero(TypeWidth);
// Equal shift amounts in opposite directions become bitwise 'and':
// lshr (shl X, C), C --> and X, C'
// shl (lshr X, C), C --> and X, C'
if (*InnerShiftConst == OuterShAmt)
return true;
ShMask.setBit(InnerShAmt);

// If the 2nd shift is bigger than the 1st, we can fold:
// If the inner shift is bigger than the outer, we can fold:
// lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
// shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
// but it isn't profitable unless we know the and'd out bits are already zero.
// Also, check that the inner shift is valid (less than the type width) or
// we'll crash trying to produce the bit mask for the 'and'.
unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
unsigned InnerShAmt = InnerShiftConst->getZExtValue();
unsigned MaskShift =
IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, CxtI))
return true;
}

return false;
// but it isn't profitable unless we know the masked out bits are already
// zero.
KnownBits Known = IC.computeKnownBits(InnerShift->getOperand(0), CxtI);
// Isolate the bits that are annihilated by the inner shift.
APInt InnerShMask = IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt)
: Known.Zero.trunc(InnerShAmt);
// Isolate the upper (resp. lower) InnerShAmt bits of the base operand of the
// inner shl (resp. lshr).
// Then:
// - lshr (shl X, C1), C2 == (shl X, C1 - C2) if the bottom C2 of the isolated
// bits are zero
// - shl (lshr X, C1), C2 == (lshr X, C1 - C2) if the top C2 of the isolated
// bits are zero
const unsigned MaxOuterShAmt =
IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt).countr_one()
: Known.Zero.trunc(InnerShAmt).countl_one();
ShMask.setLowBits(MaxOuterShAmt);
return ShMask;
}
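
To see what the returned mask encodes, consider a hand-written sketch (not from the patch's tests) in which the inner shl only annihilates bits that are already known zero, so every outer lshr amount up to 5 ends up set in the mask:

```llvm
; %a has its top five bits known zero, so 'shl %a, 5' only discards zeros.
define i8 @shl_then_lshr(i8 %x) {
  %a   = and i8 %x, 7     ; bits 3..7 of %a are zero
  %shl = shl i8 %a, 5
  %res = lshr i8 %shl, 2  ; outer amount 2 is a set bit in the returned mask
  ret i8 %res
}

; The shift pair therefore folds to a single shift:
;   %res = shl i8 %a, 3   ; 5 - 2
```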

/// See if we can compute the specified value, but shifted logically to the left
/// or right by some number of bits. This should return true if the expression
/// can be computed for the same cost as the current expression tree. This is
/// used to eliminate extraneous shifting from things like:
/// %C = shl i128 %A, 64
/// %D = shl i128 %B, 96
/// %E = or i128 %C, %D
/// %F = lshr i128 %E, 64
/// where the client will ask if E can be computed shifted right by 64-bits. If
/// this succeeds, getShiftedValue() will be called to produce the value.
static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombinerImpl &IC, Instruction *CxtI) {
/// Given a bitmask \p ShiftMask of desired shift amounts, determine the submask
/// of bits corresponding to shift amounts X for which the given expression \p V
/// can be computed for at worst the same cost as the current expression tree
/// when shifted by X. For each set bit in the \p ShiftMask afterward,
/// getShiftedValue() can produce the corresponding value.
///
/// \returns true if and only if at least one bit of the \p ShiftMask is set
/// after refinement.
static bool refineEvaluableShiftMask(Value *V, APInt &ShiftMask,
bool IsLeftShift, InstCombinerImpl &IC,
Instruction *CxtI) {
// We can always evaluate immediate constants.
if (match(V, m_ImmConstant()))
return true;

Instruction *I = dyn_cast<Instruction>(V);
if (!I) return false;
if (!I) {
ShiftMask.clearAllBits();
return false;
}

// We can't mutate something that has multiple uses: doing so would
// require duplicating the instruction in general, which isn't profitable.
if (!I->hasOneUse()) return false;
if (!I->hasOneUse()) {
ShiftMask.clearAllBits();
return false;
}

switch (I->getOpcode()) {
default: return false;
default: {
ShiftMask.clearAllBits();
return false;
}
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
// Bitwise operators can all be evaluated shifted arbitrarily.
return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
return refineEvaluableShiftMask(I->getOperand(0), ShiftMask, IsLeftShift,
IC, I) &&
refineEvaluableShiftMask(I->getOperand(1), ShiftMask, IsLeftShift,
IC, I);

case Instruction::Shl:
case Instruction::LShr:
return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
case Instruction::LShr: {
ShiftMask &= getEvaluableShiftedShiftMask(IsLeftShift, I, IC, CxtI);
return !ShiftMask.isZero();
}

case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
return refineEvaluableShiftMask(TrueVal, ShiftMask, IsLeftShift, IC, SI) &&
refineEvaluableShiftMask(FalseVal, ShiftMask, IsLeftShift, IC, SI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
// get into trouble with cyclic PHIs here because we only consider
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
if (!refineEvaluableShiftMask(IncValue, ShiftMask, IsLeftShift, IC, PN))
return false;
return true;
}
case Instruction::Mul: {
const APInt *MulConst;
// We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C')
return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
if (IsLeftShift || !match(I->getOperand(1), m_APInt(MulConst)) ||
!MulConst->isNegatedPowerOf2()) {
ShiftMask.clearAllBits();
return false;
}
ShiftMask &=
APInt::getOneBitSet(ShiftMask.getBitWidth(), MulConst->countr_zero());
return !ShiftMask.isZero();
}
}
}
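
As a concrete instance of the Mul case (an illustrative sketch with made-up value names), a multiply by a negated power of two followed by a matching lshr reduces to a negate plus mask:

```llvm
define i8 @mul_negpow2_lshr(i8 %x) {
  %m = mul i8 %x, -8   ; -8 == -(1 << 3), so countr_zero == 3
  %r = lshr i8 %m, 3   ; shift amount matches, so bit 3 survives in ShiftMask
  ret i8 %r
}

; getShiftedValue() can then produce:
;   %neg = sub i8 0, %x
;   %r   = and i8 %neg, 31
```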

/// See if we can compute the specified value, but shifted logically to the left
/// or right by some number of bits. This should return true if the expression
/// can be computed for the same cost as the current expression tree. This is
/// used to eliminate extraneous shifting from things like:
/// %C = shl i128 %A, 64
/// %D = shl i128 %B, 96
/// %E = or i128 %C, %D
/// %F = lshr i128 %E, 64
/// where the client will ask if E can be computed shifted right by 64-bits. If
/// this succeeds, getShiftedValue() will be called to produce the value.
static bool canEvaluateShifted(Value *V, unsigned ShAmt, bool IsLeftShift,
InstCombinerImpl &IC, Instruction *CxtI) {
APInt ShiftMask =
APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), ShAmt);
return refineEvaluableShiftMask(V, ShiftMask, IsLeftShift, IC, CxtI);
}
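
Here is a sketch of the doc comment's motivating pattern in a form that actually meets the known-bits requirement (%b is zero-extended, so the bits that the inner shl of %B would mask out are known zero):

```llvm
define i128 @or_of_shifts(i128 %A, i32 %b) {
  %B = zext i32 %b to i128
  %C = shl i128 %A, 64
  %D = shl i128 %B, 96
  %E = or i128 %C, %D
  %F = lshr i128 %E, 64
  ret i128 %F
}

; %E can be computed shifted right by 64, eliminating the outer lshr:
;   %C2 = and i128 %A, 18446744073709551615  ; shl 64 + lshr 64 keeps the low 64 bits
;   %D2 = shl i128 %B, 32                    ; 96 - 64
;   %F  = or i128 %C2, %D2
```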

/// Fold OuterShift (InnerShift X, C1), C2.
/// See canEvaluateShiftedShift() for the constraints on these instructions.
static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
@@ -978,6 +1025,50 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}

/// If the operand \p Op of a zext-ed left shift \p I is a logically
/// right-shifted value, try to fold the opposing shifts.
static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
unsigned ShlAmt,
InstCombinerImpl &IC,
const DataLayout &DL) {
Type *DestTy = I.getType();
const unsigned InnerBitWidth = Op->getType()->getScalarSizeInBits();

// Determine if the operand is effectively right-shifted by counting the
// known leading zero bits.
KnownBits Known = IC.computeKnownBits(Op, nullptr);
const unsigned MaxInnerShrAmt = Known.countMinLeadingZeros();
if (MaxInnerShrAmt == 0)
return nullptr;
APInt ShrMask =
APInt::getLowBitsSet(InnerBitWidth, std::min(MaxInnerShrAmt, ShlAmt) + 1);

// Undo the maximal inner right shift amount that simplifies the overall
// computation.
if (!refineEvaluableShiftMask(Op, ShrMask, /*IsLeftShift=*/true, IC, nullptr))
return nullptr;

const unsigned InnerShrAmt = ShrMask.getActiveBits() - 1;
if (InnerShrAmt == 0)
return nullptr;
assert(InnerShrAmt <= ShlAmt);

const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
if (ReducedShlAmt == 0)
return new ZExtInst(NewOp, DestTy);

Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
NewZExt->takeName(I.getOperand(0));
auto *NewShl = BinaryOperator::CreateShl(
NewZExt, ConstantInt::get(DestTy, ReducedShlAmt));

// New shl inherits all flags from the original shl instruction.
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
return NewShl;
}
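
A sketch of the headline lshr -> zext -> shl fold this helper performs (widths and names chosen for illustration; compare the updated tests below): the inner right shift is cancelled against the outer left shift, leaving a mask, the zext, and a smaller shl:

```llvm
define i64 @lshr_zext_shl(i32 %x) {
  %shr = lshr i32 %x, 8
  %ext = zext i32 %shr to i64
  %shl = shl i64 %ext, 32
  ret i64 %shl
}

; The lshr by 8 is undone and the outer shift amount shrinks by 8:
;   %mask = and i32 %x, -256        ; the bits the lshr dropped stay dropped
;   %ext2 = zext i32 %mask to i64
;   %shl2 = shl i64 %ext2, 24       ; 32 - 8
```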

// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
assert(I.isShift() && "Expected a shift as input");
@@ -1062,14 +1153,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();

// shl (zext X), C --> zext (shl X, C)
// This is only valid if X would have zeros shifted out.
Value *X;
if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
// shl (zext X), C --> zext (shl X, C)
// This is only valid if X would have zeros shifted out.
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
if (ShAmtC < SrcWidth &&
MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);

// Otherwise, try to cancel the outer shl with a lshr inside the zext.
if (Instruction *V = foldShrThroughZExtedShl(I, X, ShAmtC, *this, DL))
return V;
}

// (X >> C) << C --> X & (-1 << C)
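For the first branch of this hunk, a minimal sketch (assuming the bits shifted out of X are known zero) of shl (zext X), C --> zext (shl X, C):

```llvm
define i16 @shl_zext_narrow(i8 %x) {
  %lo  = and i8 %x, 15    ; top 4 bits of %lo are zero
  %ext = zext i8 %lo to i16
  %shl = shl i16 %ext, 4
  ret i16 %shl
}

; Since no set bits are shifted out, the shl can be done in i8:
;   %lo2 = shl i8 %lo, 4
;   %res = zext i8 %lo2 to i16
```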
6 changes: 3 additions & 3 deletions llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
; CHECK-SAME: i8 [[X:%.*]]) {
; CHECK-NEXT: [[ASHR:%.*]] = lshr i8 [[X]], 5
; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
; CHECK-NEXT: [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
; CHECK-NEXT: [[ASHR:%.*]] = and i8 [[X]], 96
; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
; CHECK-NEXT: [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
; CHECK-NEXT: [[AND14:%.*]] = and i16 [[NSB1]], 16384
; CHECK-NEXT: [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
; CHECK-NEXT: call void @escape(i16 [[ADD14]])
6 changes: 3 additions & 3 deletions llvm/test/Transforms/InstCombine/iX-ext-split.ll
@@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
; CHECK-NEXT: [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
; CHECK-NEXT: [[SIGN:%.*]] = lshr i32 [[X]], 31
; CHECK-NEXT: [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
; CHECK-NEXT: [[SIGN:%.*]] = and i32 [[X]], -2147483648
; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
; CHECK-NEXT: [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
; CHECK-NEXT: ret i128 [[RES]]
;