Don't perform unsigned comparisons for signed integers #124122

Closed
wants to merge 3 commits into from
12 changes: 10 additions & 2 deletions compiler/rustc_mir_transform/src/match_branches.rs
@@ -368,6 +368,8 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
return None;
}
Member

Why does it make sense to compare the lengths of the basic blocks...?!?

Member Author

Because we are currently only merging simple BBs, we have not considered BBs that could potentially be merged even though they have different numbers of instructions.

Member

@RalfJung RalfJung Apr 19, 2024

What's a "simple BB"? If both BBs have 24 instructions, how is that okay but one having 20 and the other 24 is not? The same number of instructions tells you absolutely nothing about what the BBs are doing...

There needs to be a comment here; at first sight this seems entirely nonsensical.

Member Author

Sorry for being unclear. I mean that the merging of basic blocks only considers simple scenarios.
This is just an early bail-out: we assume that BBs of different lengths cannot be merged.
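
To make the "early bail-out" concrete, here is a minimal sketch; it is not the pass's actual code. `Stmt` and `same_shape` are hypothetical stand-ins for MIR statements and the pairing check, and the sketch assumes statements are compared pairwise by position, which is why unequal lengths can be rejected up front.

```rust
/// Hypothetical stand-in for a MIR statement; not rustc's real type.
#[derive(Debug, Clone, PartialEq)]
pub enum Stmt {
    Assign { dest: u32, value: i128 },
    Nop,
}

/// Returns true if two blocks have the same shape when their statements
/// are paired by position (assumption: this is how candidates are matched).
pub fn same_shape(a: &[Stmt], b: &[Stmt]) -> bool {
    // Cheap early bail-out: positional pairing needs equal lengths.
    if a.len() != b.len() {
        return false;
    }
    // Equal length alone proves nothing; every pair must still be checked.
    a.iter().zip(b).all(|(x, y)| match (x, y) {
        (Stmt::Assign { dest: d1, .. }, Stmt::Assign { dest: d2, .. }) => d1 == d2,
        (Stmt::Nop, Stmt::Nop) => true,
        _ => false,
    })
}
```

In this sketch the length check is only a precondition: the `zip` still has to accept every pair, which is the part the review asks to have documented.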


+ // For signed comparisons, we need to consider different bit widths,
+ // so we need to transform to i128 for comparison.
fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
l.try_to_int(l.size()).unwrap()
== ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
@@ -399,7 +401,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
Member

@RalfJung RalfJung Apr 18, 2024

Suggested change
- if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
+ if ((f_c.const_.ty().is_signed() && discr_ty.is_signed())

Otherwise you're still potentially treating something as signed that is unsigned, or vice versa.

Member

Honestly I think it's best to move this entire thing into a helper, not just int_equal.

And also, why is int_equal so different from what happens for unsigned? They should be entirely identical except for using int vs uint functions.

Member

@RalfJung RalfJung Apr 18, 2024

Ah I see, it's about converting the u128 from the SwitchInt to a ScalarInt. But that should be uniform -- do it once before the comparison. The input is not sign extended so it's always ScalarInt::try_from_uint.

So already further up, you should convert first_val and second_val to ScalarInt.

In fact we should probably change SwitchInt to store a ScalarInt rather than a raw u128, or at least make it easy to get the SwitchTarget values as ScalarInt.
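
The point about the raw `u128` can be illustrated without rustc internals. The sketch below uses plain integers instead of `ScalarInt`; `as_unsigned` and `as_signed` are hypothetical helper names. It assumes the switch value is stored zero-extended, as stated above, and shows that signedness only enters when the bits are read back at the discriminant's width.

```rust
/// Zero-extended reading: keep only the low `bits` bits of the raw value.
pub fn as_unsigned(raw: u128, bits: u32) -> u128 {
    assert!((1..=128).contains(&bits));
    if bits == 128 {
        raw
    } else {
        raw & ((1u128 << bits) - 1)
    }
}

/// Sign-extended reading of the low `bits` bits of the raw value.
pub fn as_signed(raw: u128, bits: u32) -> i128 {
    assert!((1..=128).contains(&bits));
    let shift = 128 - bits;
    // Move the value's sign bit into bit 127, then shift back arithmetically.
    ((raw << shift) as i128) >> shift
}
```

With the switch values from the `match_i8_i16_failed_2_a` test, `as_signed(253, 8)` is `-3` and `as_signed(255, 8)` is `-1`, while `as_unsigned(255, 8)` stays `255`.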

Member

Hold on a sec... isn't what you actually want to do here some sort of cast? I don't know the right direction, but -- basically you want to cast the discriminant value to the type of the constant (or the other way around), and then check they are equal, right?

The interpreter has the int_to_int_or_float method for that. But really you just care about this match arm. That should probably be turned into a helper somewhere it can be used by mir-opts.
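
A minimal sketch of the cast-based check described here, restricted to one concrete pair of types (an `i8` discriminant and `i16` constants). `replaceable_by_cast` is a hypothetical name, and this is not the interpreter's `int_to_int_or_float`. The idea is to perform the same cast the optimized MIR would perform and require that every arm agrees with it.

```rust
/// Returns true when every `(switch value, assigned constant)` arm of a
/// match on an `i8` discriminant agrees with the expression `discr as i16`.
pub fn replaceable_by_cast(arms: &[(u128, i16)]) -> bool {
    arms.iter().all(|&(raw, konst)| {
        // Switch values are stored zero-extended; recover the i8 first.
        let discr = raw as u8 as i8;
        // Then perform the cast that would replace the whole match.
        discr as i16 == konst
    })
}
```

Applied to the arms of the two new tests in this PR, `[(255, 255), (2, 2), (253, -3)]` and `[(255, -1), (2, 2), (253, 253)]`, the check rejects both, because one arm in each disagrees with the sign-extending cast.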

Member Author

> Otherwise you're still potentially treating something as signed that is unsigned, or vice versa.

Everything looks fine here. Any signed integer will be converted to a signed comparison.

Member Author

> In fact we should probably change SwitchInt to store a ScalarInt rather than a raw u128, or at least make it easy to get the SwitchTarget values as ScalarInt.

I have seen your new issue. :)

Member Author

> The interpreter has the int_to_int_or_float method for that. But really you just care about this match arm. That should probably be turned into a helper somewhere it can be used by mir-opts.

Perhaps this could be a separate PR?

Member

@RalfJung RalfJung Apr 19, 2024

> Everything looks fine here. Any signed integer will be converted to a signed comparison.

You're treating both values as signed if either of them is signed. That means you can be treating unsigned values as signed, which is wrong.

> Perhaps this could be a separate PR?

Perhaps, but I don't understand your current PR, so it may also be a way to turn this code into something that makes sense to more than one person. ;)

Feel free to pick a different reviewer, but I can't make sense of what this code is trying to achieve. The comments don't explain the high-level picture (what are we even trying to achieve with this complicated series of checks) and the low-level details are clearly still mixing up signedness.
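
The signedness objection can be shown with a small example, under the assumption of a `u8` discriminant and an `i16` constant. Both function names are hypothetical, and the first function is only my reading of "treat both values as signed if either is signed", not a copy of the PR's code.

```rust
/// Comparison as flagged in the review: because the constant is signed,
/// the raw 8-bit switch value is read as signed too, even though the
/// discriminant type is `u8`.
pub fn equal_if_either_signed(raw: u8, konst: i16) -> bool {
    (raw as i8) as i16 == konst
}

/// What `discr as i16` actually computes for a `u8` discriminant:
/// zero-extension, with no sign involved.
pub fn equal_after_real_cast(raw: u8, konst: i16) -> bool {
    raw as i16 == konst
}
```

For the switch value `255` and the constant `-1_i16`, the first function reports a match while the real cast yields `255`, so replacing the match with a cast on that basis would change behavior. For small values such as `2` the two functions agree.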

Member Author

> > Everything looks fine here. Any signed integer will be converted to a signed comparison.
>
> You're treating both values as signed if either of them is signed. That means you can be treating unsigned values as signed, which is wrong.

In the known test cases, it is correct. But I must carefully check the edge cases here. For safety reasons, I will later consider only signed-to-signed conversions in this PR. :)

cc @rust-lang/wg-mir-opt Perhaps someone else will directly point out this specific problem?

Member

> In the known test cases, it is correct.

That's a very low bar. The comments should give convincing reasons why it is correct for all possible MIR ever.

&& int_equal(f, first_val, discr_size)
&& int_equal(s, second_val, discr_size))
- || (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
+ || (!f_c.const_.ty().is_signed()
+ && !discr_ty.is_signed()
+ && Some(f)
+ == ScalarInt::try_from_uint(first_val, f.size())
&& Some(s)
== ScalarInt::try_from_uint(second_val, s.size())) =>
{
@@ -449,7 +454,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
{
continue;
}
- if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
+ if !is_signed
+ && !s_c.const_.ty().is_signed()
+ && Some(f) == ScalarInt::try_from_uint(other_val, f.size())
+ {
continue;
}
return None;
@@ -5,42 +5,37 @@
debug i => _1;
let mut _0: u128;
let mut _2: i128;
+ let mut _3: i128;

bb0: {
_2 = discriminant(_1);
- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
- }
-
- bb1: {
- unreachable;
- }
-
- bb2: {
- _0 = const core::num::<impl u128>::MAX;
- goto -> bb6;
- }
-
- bb3: {
- _0 = const 1_u128;
- goto -> bb6;
- }
-
- bb4: {
- _0 = const 2_u128;
- goto -> bb6;
- }
-
- bb5: {
- _0 = const 3_u128;
- goto -> bb6;
- }
-
- bb6: {
+ StorageLive(_3);
+ _3 = move _2;
+ _0 = _3 as u128 (IntToInt);
+ StorageDead(_3);
switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = const core::num::<impl u128>::MAX;
goto -> bb6;
}

bb3: {
_0 = const 1_u128;
goto -> bb6;
}

bb4: {
_0 = const 2_u128;
goto -> bb6;
}

bb5: {
_0 = const 3_u128;
goto -> bb6;
}

bb6: {
return;
}
}
@@ -0,0 +1,37 @@
- // MIR for `match_i8_i16_failed_2_a` before MatchBranchSimplification
+ // MIR for `match_i8_i16_failed_2_a` after MatchBranchSimplification

fn match_i8_i16_failed_2_a(_1: EnumAi8) -> i16 {
debug i => _1;
let mut _0: i16;
let mut _2: i8;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = const -3_i16;
goto -> bb5;
}

bb3: {
_0 = const 255_i16;
goto -> bb5;
}

bb4: {
_0 = const 2_i16;
goto -> bb5;
}

bb5: {
return;
}
}

@@ -0,0 +1,37 @@
- // MIR for `match_i8_i16_failed_2_b` before MatchBranchSimplification
+ // MIR for `match_i8_i16_failed_2_b` after MatchBranchSimplification

fn match_i8_i16_failed_2_b(_1: EnumAi8) -> i16 {
debug i => _1;
let mut _0: i16;
let mut _2: i8;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = const 253_i16;
goto -> bb5;
}

bb3: {
_0 = const -1_i16;
goto -> bb5;
}

bb4: {
_0 = const 2_i16;
goto -> bb5;
}

bb5: {
return;
}
}

30 changes: 27 additions & 3 deletions tests/mir-opt/matches_reduce_branches.rs
@@ -225,6 +225,31 @@ fn match_i8_i16_failed(i: EnumAi8) -> i16 {
}
}

// We cannot transform it. Even though `-1i8` and `255i16` have the same low 8 bits,
// what we actually require is that they are considered equal in a signed comparison
// after sign-extending to the larger type.
// EMIT_MIR matches_reduce_branches.match_i8_i16_failed_2_a.MatchBranchSimplification.diff
fn match_i8_i16_failed_2_a(i: EnumAi8) -> i16 {
// CHECK-LABEL: fn match_i8_i16_failed_2_a(
// CHECK: switchInt
match i {
EnumAi8::A => 255,
EnumAi8::B => 2,
EnumAi8::C => -3,
}
}

// EMIT_MIR matches_reduce_branches.match_i8_i16_failed_2_b.MatchBranchSimplification.diff
fn match_i8_i16_failed_2_b(i: EnumAi8) -> i16 {
// CHECK-LABEL: fn match_i8_i16_failed_2_b(
// CHECK: switchInt
match i {
EnumAi8::A => -1,
EnumAi8::B => 2,
EnumAi8::C => 253,
}
}

#[repr(i16)]
enum EnumAi16 {
A = -1,
@@ -253,12 +278,11 @@ enum EnumAi128 {
D = -1,
}

+ // FIXME: This transform is reasonable.
// EMIT_MIR matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff
fn match_i128_u128(i: EnumAi128) -> u128 {
// CHECK-LABEL: fn match_i128_u128(
- // CHECK-NOT: switchInt
- // CHECK: _0 = _3 as u128 (IntToInt);
- // CHECH: return
+ // CHECK: switchInt
match i {
EnumAi128::A => 1,
EnumAi128::B => 2,