Skip to content

Commit 1322f92

Browse files
committed
Auto merge of #107009 - cjgillot:jump-threading, r=pnkfelix
Implement jump threading MIR opt This pass is an attempt to generalize `ConstGoto` and `SeparateConstSwitch` passes into a more complete jump threading pass. This pass is rather heavy, as it performs a truncated backwards DFS on MIR starting from each `SwitchInt` terminator. This backwards DFS remains very limited, as it only walks through `Goto` terminators. It is build to support constants and discriminants, and a propagating through a very limited set of operations. The pass successfully manages to disentangle the `Some(x?)` use case and the DFA use case. It still needs a few tests before being ready.
2 parents e2068cd + dd08dd4 commit 1322f92

File tree

31 files changed

+2797
-137
lines changed

31 files changed

+2797
-137
lines changed

Cargo.lock

+1
Original file line numberDiff line numberDiff line change
@@ -4279,6 +4279,7 @@ dependencies = [
42794279
"coverage_test_macros",
42804280
"either",
42814281
"itertools",
4282+
"rustc_arena",
42824283
"rustc_ast",
42834284
"rustc_attr",
42844285
"rustc_const_eval",

compiler/rustc_middle/src/mir/terminator.rs

+9
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ impl SwitchTargets {
2828
Self { values: smallvec![value], targets: smallvec![then, else_] }
2929
}
3030

31+
/// Inverse of `SwitchTargets::static_if`.
32+
pub fn as_static_if(&self) -> Option<(u128, BasicBlock, BasicBlock)> {
33+
if let &[value] = &self.values[..] && let &[then, else_] = &self.targets[..] {
34+
Some((value, then, else_))
35+
} else {
36+
None
37+
}
38+
}
39+
3140
/// Returns the fallback target that is jumped to when none of the values match the operand.
3241
pub fn otherwise(&self) -> BasicBlock {
3342
*self.targets.last().unwrap()

compiler/rustc_mir_dataflow/src/value_analysis.rs

+124-25
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,19 @@ impl<V: Clone> Clone for State<V> {
463463
}
464464
}
465465

466-
impl<V: Clone + HasTop + HasBottom> State<V> {
466+
impl<V: Clone> State<V> {
467+
pub fn new(init: V, map: &Map) -> State<V> {
468+
let values = IndexVec::from_elem_n(init, map.value_count);
469+
State(StateData::Reachable(values))
470+
}
471+
472+
pub fn all(&self, f: impl Fn(&V) -> bool) -> bool {
473+
match self.0 {
474+
StateData::Unreachable => true,
475+
StateData::Reachable(ref values) => values.iter().all(f),
476+
}
477+
}
478+
467479
pub fn is_reachable(&self) -> bool {
468480
matches!(&self.0, StateData::Reachable(_))
469481
}
@@ -472,7 +484,10 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
472484
self.0 = StateData::Unreachable;
473485
}
474486

475-
pub fn flood_all(&mut self) {
487+
pub fn flood_all(&mut self)
488+
where
489+
V: HasTop,
490+
{
476491
self.flood_all_with(V::TOP)
477492
}
478493

@@ -481,28 +496,52 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
481496
values.raw.fill(value);
482497
}
483498

499+
/// Assign `value` to all places that are contained in `place` or may alias one.
484500
pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
485-
let StateData::Reachable(values) = &mut self.0 else { return };
486-
map.for_each_aliasing_place(place, None, &mut |vi| {
487-
values[vi] = value.clone();
488-
});
501+
self.flood_with_tail_elem(place, None, map, value)
489502
}
490503

491-
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) {
504+
/// Assign `TOP` to all places that are contained in `place` or may alias one.
505+
pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map)
506+
where
507+
V: HasTop,
508+
{
492509
self.flood_with(place, map, V::TOP)
493510
}
494511

512+
/// Assign `value` to the discriminant of `place` and all places that may alias it.
495513
pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) {
496-
let StateData::Reachable(values) = &mut self.0 else { return };
497-
map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |vi| {
498-
values[vi] = value.clone();
499-
});
514+
self.flood_with_tail_elem(place, Some(TrackElem::Discriminant), map, value)
500515
}
501516

502-
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) {
517+
/// Assign `TOP` to the discriminant of `place` and all places that may alias it.
518+
pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map)
519+
where
520+
V: HasTop,
521+
{
503522
self.flood_discr_with(place, map, V::TOP)
504523
}
505524

525+
/// This method is the most general version of the `flood_*` method.
526+
///
527+
/// Assign `value` on the given place and all places that may alias it. In particular, when
528+
/// the given place has a variant downcast, we invoke the function on all the other variants.
529+
///
530+
/// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track
531+
/// as such.
532+
pub fn flood_with_tail_elem(
533+
&mut self,
534+
place: PlaceRef<'_>,
535+
tail_elem: Option<TrackElem>,
536+
map: &Map,
537+
value: V,
538+
) {
539+
let StateData::Reachable(values) = &mut self.0 else { return };
540+
map.for_each_aliasing_place(place, tail_elem, &mut |vi| {
541+
values[vi] = value.clone();
542+
});
543+
}
544+
506545
/// Low-level method that assigns to a place.
507546
/// This does nothing if the place is not tracked.
508547
///
@@ -553,44 +592,104 @@ impl<V: Clone + HasTop + HasBottom> State<V> {
553592
}
554593

555594
/// Helper method to interpret `target = result`.
556-
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
595+
pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map)
596+
where
597+
V: HasTop,
598+
{
557599
self.flood(target, map);
558600
if let Some(target) = map.find(target) {
559601
self.insert_idx(target, result, map);
560602
}
561603
}
562604

563605
/// Helper method for assignments to a discriminant.
564-
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map) {
606+
pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace<V>, map: &Map)
607+
where
608+
V: HasTop,
609+
{
565610
self.flood_discr(target, map);
566611
if let Some(target) = map.find_discr(target) {
567612
self.insert_idx(target, result, map);
568613
}
569614
}
570615

616+
/// Retrieve the value stored for a place, or `None` if it is not tracked.
617+
pub fn try_get(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
618+
let place = map.find(place)?;
619+
self.try_get_idx(place, map)
620+
}
621+
622+
/// Retrieve the discriminant stored for a place, or `None` if it is not tracked.
623+
pub fn try_get_discr(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
624+
let place = map.find_discr(place)?;
625+
self.try_get_idx(place, map)
626+
}
627+
628+
/// Retrieve the slice length stored for a place, or `None` if it is not tracked.
629+
pub fn try_get_len(&self, place: PlaceRef<'_>, map: &Map) -> Option<V> {
630+
let place = map.find_len(place)?;
631+
self.try_get_idx(place, map)
632+
}
633+
634+
/// Retrieve the value stored for a place index, or `None` if it is not tracked.
635+
pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option<V> {
636+
match &self.0 {
637+
StateData::Reachable(values) => {
638+
map.places[place].value_index.map(|v| values[v].clone())
639+
}
640+
StateData::Unreachable => None,
641+
}
642+
}
643+
571644
/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
572-
pub fn get(&self, place: PlaceRef<'_>, map: &Map) -> V {
573-
map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::TOP)
645+
///
646+
/// This method returns ⊥ if the place is tracked and the state is unreachable.
647+
pub fn get(&self, place: PlaceRef<'_>, map: &Map) -> V
648+
where
649+
V: HasBottom + HasTop,
650+
{
651+
match &self.0 {
652+
StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP),
653+
// Because this is unreachable, we can return any value we want.
654+
StateData::Unreachable => V::BOTTOM,
655+
}
574656
}
575657

576658
/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
577-
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V {
578-
match map.find_discr(place) {
579-
Some(place) => self.get_idx(place, map),
580-
None => V::TOP,
659+
///
660+
/// This method returns ⊥ the current state is unreachable.
661+
pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V
662+
where
663+
V: HasBottom + HasTop,
664+
{
665+
match &self.0 {
666+
StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP),
667+
// Because this is unreachable, we can return any value we want.
668+
StateData::Unreachable => V::BOTTOM,
581669
}
582670
}
583671

584672
/// Retrieve the value stored for a place, or ⊤ if it is not tracked.
585-
pub fn get_len(&self, place: PlaceRef<'_>, map: &Map) -> V {
586-
match map.find_len(place) {
587-
Some(place) => self.get_idx(place, map),
588-
None => V::TOP,
673+
///
674+
/// This method returns ⊥ the current state is unreachable.
675+
pub fn get_len(&self, place: PlaceRef<'_>, map: &Map) -> V
676+
where
677+
V: HasBottom + HasTop,
678+
{
679+
match &self.0 {
680+
StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP),
681+
// Because this is unreachable, we can return any value we want.
682+
StateData::Unreachable => V::BOTTOM,
589683
}
590684
}
591685

592686
/// Retrieve the value stored for a place index, or ⊤ if it is not tracked.
593-
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V {
687+
///
688+
/// This method returns ⊥ the current state is unreachable.
689+
pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V
690+
where
691+
V: HasBottom + HasTop,
692+
{
594693
match &self.0 {
595694
StateData::Reachable(values) => {
596695
map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP)

compiler/rustc_mir_transform/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
1111
tracing = "0.1"
1212
either = "1"
1313
rustc_ast = { path = "../rustc_ast" }
14+
rustc_arena = { path = "../rustc_arena" }
1415
rustc_attr = { path = "../rustc_attr" }
1516
rustc_data_structures = { path = "../rustc_data_structures" }
1617
rustc_errors = { path = "../rustc_errors" }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use rustc_middle::mir::visit::*;
2+
use rustc_middle::mir::*;
3+
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
4+
5+
const INSTR_COST: usize = 5;
6+
const CALL_PENALTY: usize = 25;
7+
const LANDINGPAD_PENALTY: usize = 50;
8+
const RESUME_PENALTY: usize = 45;
9+
10+
/// Verify that the callee body is compatible with the caller.
11+
#[derive(Clone)]
12+
pub(crate) struct CostChecker<'b, 'tcx> {
13+
tcx: TyCtxt<'tcx>,
14+
param_env: ParamEnv<'tcx>,
15+
cost: usize,
16+
callee_body: &'b Body<'tcx>,
17+
instance: Option<ty::Instance<'tcx>>,
18+
}
19+
20+
impl<'b, 'tcx> CostChecker<'b, 'tcx> {
21+
pub fn new(
22+
tcx: TyCtxt<'tcx>,
23+
param_env: ParamEnv<'tcx>,
24+
instance: Option<ty::Instance<'tcx>>,
25+
callee_body: &'b Body<'tcx>,
26+
) -> CostChecker<'b, 'tcx> {
27+
CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
28+
}
29+
30+
pub fn cost(&self) -> usize {
31+
self.cost
32+
}
33+
34+
fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
35+
if let Some(instance) = self.instance {
36+
instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
37+
} else {
38+
v
39+
}
40+
}
41+
}
42+
43+
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
44+
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
45+
// Don't count StorageLive/StorageDead in the inlining cost.
46+
match statement.kind {
47+
StatementKind::StorageLive(_)
48+
| StatementKind::StorageDead(_)
49+
| StatementKind::Deinit(_)
50+
| StatementKind::Nop => {}
51+
_ => self.cost += INSTR_COST,
52+
}
53+
}
54+
55+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
56+
let tcx = self.tcx;
57+
match terminator.kind {
58+
TerminatorKind::Drop { ref place, unwind, .. } => {
59+
// If the place doesn't actually need dropping, treat it like a regular goto.
60+
let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
61+
if ty.needs_drop(tcx, self.param_env) {
62+
self.cost += CALL_PENALTY;
63+
if let UnwindAction::Cleanup(_) = unwind {
64+
self.cost += LANDINGPAD_PENALTY;
65+
}
66+
} else {
67+
self.cost += INSTR_COST;
68+
}
69+
}
70+
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
71+
let fn_ty = self.instantiate_ty(f.const_.ty());
72+
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
73+
// Don't give intrinsics the extra penalty for calls
74+
INSTR_COST
75+
} else {
76+
CALL_PENALTY
77+
};
78+
if let UnwindAction::Cleanup(_) = unwind {
79+
self.cost += LANDINGPAD_PENALTY;
80+
}
81+
}
82+
TerminatorKind::Assert { unwind, .. } => {
83+
self.cost += CALL_PENALTY;
84+
if let UnwindAction::Cleanup(_) = unwind {
85+
self.cost += LANDINGPAD_PENALTY;
86+
}
87+
}
88+
TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
89+
TerminatorKind::InlineAsm { unwind, .. } => {
90+
self.cost += INSTR_COST;
91+
if let UnwindAction::Cleanup(_) = unwind {
92+
self.cost += LANDINGPAD_PENALTY;
93+
}
94+
}
95+
_ => self.cost += INSTR_COST,
96+
}
97+
}
98+
}

0 commit comments

Comments
 (0)