-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[DemandedBits] Support non-constant shift amounts #148880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This is done by supporting shift operators to handle non constant shift amount.
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-analysis Author: Panagiotis K (karouzakisp) ChangesThis is part of a larger PR: #148853 Here we add support to the shift operators to handle non-constant shift operands. Full diff: https://github.com/llvm/llvm-project/pull/148880.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
index 6694d5cc06c8c..2d30575c19130 100644
--- a/llvm/lib/Analysis/DemandedBits.cpp
+++ b/llvm/lib/Analysis/DemandedBits.cpp
@@ -36,6 +36,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdint>
@@ -183,6 +184,17 @@ void DemandedBits::determineLiveOperandBits(
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
else if (S->hasNoUnsignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
+ } else {
+ ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
+ unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
+ unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
+ // similar to Lshr case
+ AB = (AOut.lshr(Min) | AOut.lshr(Max));
+ const auto *S = cast<ShlOperator>(UserI);
+ if (S->hasNoSignedWrap())
+ AB |= APInt::getHighBitsSet(BitWidth, Max + 1);
+ else if (S->hasNoUnsignedWrap())
+ AB |= APInt::getHighBitsSet(BitWidth, Max);
}
}
break;
@@ -197,6 +209,19 @@ void DemandedBits::determineLiveOperandBits(
// (they must be zero).
if (cast<LShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ } else {
+ ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
+ unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
+ unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
+ // Suppose AOut == 0b0000 0011
+ // [min, max] = [1, 3]
+ // shift by 1 we get 0b0000 0110
+ // shift by 2 we get 0b0000 1100
+ // shift by 3 we get 0b0001 1000
+ // we take the or here because need to cover all the above possibilities
+ AB = (AOut.shl(Min) | AOut.shl(Max));
+ if (cast<LShrOperator>(UserI)->isExact())
+ AB |= APInt::getLowBitsSet(BitWidth, Max);
}
}
break;
@@ -217,6 +242,27 @@ void DemandedBits::determineLiveOperandBits(
// (they must be zero).
if (cast<AShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ } else {
+ ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
+ unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
+ unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
+ AB = (AOut.shl(Min) | AOut.shl(Max));
+
+ if (Max) {
+ // Suppose AOut = 0011 1100
+ // [min, max] = [1, 3]
+ // ShiftAmount = 1 : Mask is 1000 0000
+ // ShiftAmount = 2 : Mask is 1100 0000
+ // ShiftAmount = 3 : Mask is 1110 0000
+ // The Mask with Max covers every case in [min, max],
+ // so we are done
+ if ((AOut & APInt::getHighBitsSet(BitWidth, Max)).getBoolValue())
+ AB.setSignBit();
+ }
+ // If the shift is exact, then the low bits are not dead
+ // (they must be zero).
+ if (cast<AShrOperator>(UserI)->isExact())
+ AB |= APInt::getLowBitsSet(BitWidth, Max);
}
}
break;
diff --git a/llvm/test/Analysis/DemandedBits/shl.ll b/llvm/test/Analysis/DemandedBits/shl.ll
index e41f5f4107735..c3313a93c1e85 100644
--- a/llvm/test/Analysis/DemandedBits/shl.ll
+++ b/llvm/test/Analysis/DemandedBits/shl.ll
@@ -57,10 +57,56 @@ define i8 @test_shl(i32 %a, i32 %b) {
; CHECK-DAG: DemandedBits: 0xff for %shl.t = trunc i32 %shl to i8
; CHECK-DAG: DemandedBits: 0xff for %shl in %shl.t = trunc i32 %shl to i8
; CHECK-DAG: DemandedBits: 0xff for %shl = shl i32 %a, %b
-; CHECK-DAG: DemandedBits: 0xffffffff for %a in %shl = shl i32 %a, %b
+; CHECK-DAG: DemandedBits: 0xff for %a in %shl = shl i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %shl = shl i32 %a, %b
;
%shl = shl i32 %a, %b
%shl.t = trunc i32 %shl to i8
ret i8 %shl.t
}
+
+define i8 @test_shl_var_amount(i32 %a, i32 %b){
+; CHECK-LABEL: 'test_shl_var_amount'
+; CHECK-DAG: DemandedBits: 0xff for %5 = trunc i32 %4 to i8
+; CHECK-DAG: DemandedBits: 0xff for %4 in %5 = trunc i32 %4 to i8
+; CHECK-DAG: DemandedBits: 0xff for %4 = shl i32 %1, %3
+; CHECK-DAG: DemandedBits: 0xff for %1 in %4 = shl i32 %1, %3
+; CHECK-DAG: DemandedBits: 0xffffffff for %3 in %4 = shl i32 %1, %3
+; CHECK-DAG: DemandedBits: 0xff for %2 = trunc i32 %1 to i8
+; CHECK-DAG: DemandedBits: 0xff for %1 in %2 = trunc i32 %1 to i8
+; CHECK-DAG: DemandedBits: 0xffffffff for %3 = zext i8 %2 to i32
+; CHECK-DAG: DemandedBits: 0xff for %2 in %3 = zext i8 %2 to i32
+; CHECK-DAG: DemandedBits: 0xff for %1 = add nsw i32 %a, %b
+; CHECK-DAG: DemandedBits: 0xff for %a in %1 = add nsw i32 %a, %b
+; CHECK-DAG: DemandedBits: 0xff for %b in %1 = add nsw i32 %a, %b
+;
+ %1 = add nsw i32 %a, %b
+ %2 = trunc i32 %1 to i8
+ %3 = zext i8 %2 to i32
+ %4 = shl i32 %1, %3
+ %5 = trunc i32 %4 to i8
+ ret i8 %5
+}
+
+define i8 @test_shl_var_amount_nsw(i32 %a, i32 %b){
+ ; CHECK-LABEL 'test_shl_var_amount_nsw'
+ ; CHECK-DAG: DemandedBits: 0xff for %5 = trunc i32 %4 to i8
+ ; CHECK-DAG: DemandedBits: 0xff for %4 in %5 = trunc i32 %4 to i8
+ ; CHECK-DAG: DemandedBits: 0xff for %4 = shl nsw i32 %1, %3
+ ; CHECK-DAG: DemandedBits: 0xffffffff for %1 in %4 = shl nsw i32 %1, %3
+ ; CHECK-DAG: DemandedBits: 0xffffffff for %3 in %4 = shl nsw i32 %1, %3
+ ; CHECK-DAG: DemandedBits: 0xffffffff for %3 = zext i8 %2 to i32
+ ; CHECK-DAG: DemandedBits: 0xff for %2 in %3 = zext i8 %2 to i32
+ ; CHECK-DAG: DemandedBits: 0xff for %2 = trunc i32 %1 to i8
+ ; CHECK-DAG: DemandedBits: 0xff for %1 in %2 = trunc i32 %1 to i8
+ ; CHECK-DAG: DemandedBits: 0xffffffff for %1 = add nsw i32 %a, %b
+ ; CHECK-DAG: DemandedBits: 0xffffffff for %a in %1 = add nsw i32 %a, %b
+ ; CHECK-DAG: DemandedBits: 0xffffffff for %b in %1 = add nsw i32 %a, %b
+ ;
+ %1 = add nsw i32 %a, %b
+ %2 = trunc i32 %1 to i8
+ %3 = zext i8 %2 to i32
+ %4 = shl nsw i32 %1, %3
+ %5 = trunc i32 %4 to i8
+ ret i8 %5
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing coverage for lshr and ashr? Could you kindly add tests for them?
Missing tests for right shifts? |
Kindly note that we only have a squash-and-merge. As a result:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please provide the alive2 proof. See also my previous comment #148853 (review)
llvm/lib/Analysis/DemandedBits.cpp
Outdated
// shift by 2 we get 0b0000 1100 | ||
// shift by 3 we get 0b0001 1000 | ||
// we take the or here because need to cover all the above possibilities | ||
AB = (AOut.shl(Min) | AOut.shl(Max)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this need to be the OR of all possible shift amounts between Min and Max? Not just the end points. Using the end points only works if the set bits in AOut are contiguous.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this need to be the OR of all possible shift amounts between Min and Max? Not just the end points. Using the end points only works if the set bits in AOut are contiguous.
Yes that's correct. I just added a function GetShiftedRange to shift between Min and Max
…dle non continued bits for AOut
I just added the tests |
I am not certain which transformation I should verify. Maybe the one on your previous comment? |
I think what we want verified is the algorithm of the analysis itself, not a particular transformation: if can express the code you wrote for DemandedBits in a language that Alive2 can verify, that would be great (this isn't exactly straight-forward, but @dtcxzyw left some hints). Think about it, and try it out: we'll help out. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Miscompilation reproducer: https://alive2.llvm.org/ce/z/bSBzWM
; bin/opt -passes=bdce test.ll -S
define i16 @src(i32 range(i32 0, 2) %x) {
entry:
%or = or i32 0, 48
%shl = shl i32 %or, %x
%trunc = trunc i32 %shl to i16
ret i16 %trunc
}
define i16 @tgt(i32 range(i32 0, 2) %x) {
entry:
%shl = shl i32 0, %x
%trunc = trunc i32 %shl to i16
ret i16 %trunc
}
Fixed, Alive verifications coming soon. Hopefully this week! |
@dtcxzyw Here are the alive2 proofs --> Please note that since my transformation contains a loop and the Alive syntax doesn't permit loops, I added various ranges. Please let me know if it's okay. @artagnon, Please let me know what you think. |
You can use a smaller integer bitwidth (e.g., i4/i8), then unroll the loop with |
Thanks for the tip. Here is the updated proof --> https://alive2.llvm.org/ce/z/tCvUT6 |
In your proof, the range of shamt is not taken into account. Updated: https://alive2.llvm.org/ce/z/n4hgkX |
@@ -76,6 +76,16 @@ void DemandedBits::determineLiveOperandBits( | |||
computeKnownBits(V2, Known2, DL, &AC, UserI, &DT); | |||
} | |||
}; | |||
auto GetShiftedRange = [&](unsigned const Min, unsigned const Max, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto GetShiftedRange = [&](unsigned const Min, unsigned const Max, | |
auto GetShiftedRange = [&](unsigned Min, unsigned Max, |
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1); | ||
// Suppose AOut == 0b0000 1001 | ||
// [min, max] = [1, 3] | ||
// shift by 1 we get 0b0001 00100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// shift by 1 we get 0b0001 00100 | |
// shift by 1 we get 0b0001 0010 |
This patch adds support for the shift operators to handle non-constant shift operands.