Skip to content

Commit 98fcb15

Browse files
Collect relevant item bounds from trait clauses for nested rigid projections, GATs
1 parent 7606c13 commit 98fcb15

File tree

4 files changed

+278
-10
lines changed

4 files changed

+278
-10
lines changed

compiler/rustc_hir_analysis/src/collect/item_bounds.rs

+208-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use super::ItemCtxt;
22
use crate::astconv::{AstConv, PredicateFilter};
3+
use rustc_data_structures::fx::FxIndexMap;
34
use rustc_hir as hir;
45
use rustc_infer::traits::util;
5-
use rustc_middle::ty::GenericArgs;
6-
use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, TypeFolder};
6+
use rustc_middle::ty::fold::shift_vars;
7+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeVisitableExt};
8+
use rustc_middle::ty::{GenericArgs, ToPredicate, TypeSuperFoldable};
79
use rustc_span::def_id::{DefId, LocalDefId};
810
use rustc_span::Span;
911

@@ -34,14 +36,99 @@ fn associated_type_bounds<'tcx>(
3436
let trait_def_id = tcx.local_parent(assoc_item_def_id);
3537
let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id);
3638

37-
let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| {
38-
match pred.kind().skip_binder() {
39-
ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty,
40-
ty::ClauseKind::Projection(proj) => proj.projection_ty.self_ty() == item_ty,
41-
ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty,
42-
_ => false,
43-
}
44-
});
39+
let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id()));
40+
let bounds_from_parent =
41+
trait_predicates.predicates.iter().copied().filter_map(|(pred, span)| {
42+
let mut clause_ty = match pred.kind().skip_binder() {
43+
ty::ClauseKind::Trait(tr) => tr.self_ty(),
44+
ty::ClauseKind::Projection(proj) => proj.projection_ty.self_ty(),
45+
ty::ClauseKind::TypeOutlives(outlives) => outlives.0,
46+
_ => return None,
47+
};
48+
49+
// The code below is quite involved, so let me explain.
50+
//
51+
// We loop here, because we also want to collect vars for nested associated items as
52+
// well. For example, given a clause like `Self::A::B`, we want to add that to the
53+
// item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is
54+
// rigid.
55+
//
56+
// Secondly, regarding bound vars, when we see a where clause that mentions a GAT
57+
// like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into
58+
// an item bound on the GAT, where all of the GAT args are substituted with the GAT's
59+
// param regions, and then keep all of the other late-bound vars in the bound around.
60+
// We need to "compress" the binder so that it doesn't mention any of those vars that
61+
// were mapped to params.
62+
let gat_vars = loop {
63+
if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() {
64+
if alias_ty.trait_ref(tcx) == item_trait_ref {
65+
break &alias_ty.args[item_trait_ref.args.len()..];
66+
} else {
67+
clause_ty = alias_ty.self_ty();
68+
continue;
69+
}
70+
}
71+
72+
return None;
73+
};
74+
// Special-case: No GAT vars, no mapping needed.
75+
if gat_vars.is_empty() {
76+
return Some((pred, span));
77+
}
78+
79+
// First, check that all of the GAT args are substituted with a unique late-bound arg.
80+
// If we find a duplicate, then it can't be mapped to the definition's params.
81+
let mut mapping = FxIndexMap::default();
82+
let generics = tcx.generics_of(assoc_item_def_id);
83+
for (param, var) in std::iter::zip(&generics.params, gat_vars) {
84+
let existing = match var.unpack() {
85+
ty::GenericArgKind::Lifetime(re) => {
86+
if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() {
87+
mapping.insert(bv.var, tcx.mk_param_from_def(param))
88+
} else {
89+
return None;
90+
}
91+
}
92+
ty::GenericArgKind::Type(ty) => {
93+
if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() {
94+
mapping.insert(bv.var, tcx.mk_param_from_def(param))
95+
} else {
96+
return None;
97+
}
98+
}
99+
ty::GenericArgKind::Const(ct) => {
100+
if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() {
101+
mapping.insert(bv, tcx.mk_param_from_def(param))
102+
} else {
103+
return None;
104+
}
105+
}
106+
};
107+
108+
if existing.is_some() {
109+
return None;
110+
}
111+
}
112+
113+
// Finally, map all of the args in the GAT to the params we expect, and compress
114+
// the remaining late-bound vars so that they count up from var 0.
115+
let mut folder = MapAndCompressBoundVars {
116+
tcx,
117+
binder: ty::INNERMOST,
118+
still_bound_vars: vec![],
119+
mapping,
120+
};
121+
let pred = pred.kind().skip_binder().fold_with(&mut folder);
122+
123+
Some((
124+
ty::Binder::bind_with_vars(
125+
pred,
126+
tcx.mk_bound_variable_kinds(&folder.still_bound_vars),
127+
)
128+
.to_predicate(tcx),
129+
span,
130+
))
131+
});
45132

46133
let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses().chain(bounds_from_parent));
47134
debug!(
@@ -52,6 +139,117 @@ fn associated_type_bounds<'tcx>(
52139
all_bounds
53140
}
54141

142+
struct MapAndCompressBoundVars<'tcx> {
143+
tcx: TyCtxt<'tcx>,
144+
/// How deep are we? Makes sure we don't touch the vars of nested binders.
145+
binder: ty::DebruijnIndex,
146+
/// List of bound vars that remain unsubstituted because they were not
147+
/// mentioned in the GAT's args.
148+
still_bound_vars: Vec<ty::BoundVariableKind>,
149+
/// Subtle invariant: If the `GenericArg` is bound, then it should be
150+
/// stored with the debruijn index of `INNERMOST` so it can be shifted
151+
/// correctly during substitution.
152+
mapping: FxIndexMap<ty::BoundVar, ty::GenericArg<'tcx>>,
153+
}
154+
155+
impl<'tcx> TypeFolder<TyCtxt<'tcx>> for MapAndCompressBoundVars<'tcx> {
156+
fn interner(&self) -> TyCtxt<'tcx> {
157+
self.tcx
158+
}
159+
160+
fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
161+
where
162+
ty::Binder<'tcx, T>: TypeSuperFoldable<TyCtxt<'tcx>>,
163+
{
164+
self.binder.shift_in(1);
165+
let out = t.super_fold_with(self);
166+
self.binder.shift_out(1);
167+
out
168+
}
169+
170+
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
171+
if !ty.has_bound_vars() {
172+
return ty;
173+
}
174+
175+
if let ty::Bound(binder, old_bound) = *ty.kind()
176+
&& self.binder == binder
177+
{
178+
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
179+
mapped.expect_ty()
180+
} else {
181+
// If we didn't find a mapped generic, then make a new one.
182+
// Allocate a new var idx, and insert a new bound ty.
183+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
184+
self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind));
185+
let mapped = Ty::new_bound(
186+
self.tcx,
187+
ty::INNERMOST,
188+
ty::BoundTy { var, kind: old_bound.kind },
189+
);
190+
self.mapping.insert(old_bound.var, mapped.into());
191+
mapped
192+
};
193+
194+
shift_vars(self.tcx, mapped, self.binder.as_u32())
195+
} else {
196+
ty.super_fold_with(self)
197+
}
198+
}
199+
200+
fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> {
201+
if let ty::ReBound(binder, old_bound) = re.kind()
202+
&& self.binder == binder
203+
{
204+
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
205+
mapped.expect_region()
206+
} else {
207+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
208+
self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind));
209+
let mapped = ty::Region::new_bound(
210+
self.tcx,
211+
ty::INNERMOST,
212+
ty::BoundRegion { var, kind: old_bound.kind },
213+
);
214+
self.mapping.insert(old_bound.var, mapped.into());
215+
mapped
216+
};
217+
218+
shift_vars(self.tcx, mapped, self.binder.as_u32())
219+
} else {
220+
re
221+
}
222+
}
223+
224+
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
225+
if !ct.has_bound_vars() {
226+
return ct;
227+
}
228+
229+
if let ty::ConstKind::Bound(binder, old_var) = ct.kind()
230+
&& self.binder == binder
231+
{
232+
let mapped = if let Some(mapped) = self.mapping.get(&old_var) {
233+
mapped.expect_const()
234+
} else {
235+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
236+
self.still_bound_vars.push(ty::BoundVariableKind::Const);
237+
let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var, ct.ty());
238+
self.mapping.insert(old_var, mapped.into());
239+
mapped
240+
};
241+
242+
shift_vars(self.tcx, mapped, self.binder.as_u32())
243+
} else {
244+
ct.super_fold_with(self)
245+
}
246+
}
247+
248+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
249+
if !p.has_bound_vars() { p } else { p.super_fold_with(self) }
250+
}
251+
}
252+
55253
/// Opaque types don't inherit bounds from their parent: for return position
56254
/// impl trait it isn't possible to write a suitable predicate on the
57255
/// containing function and for type-alias impl trait we don't have a backwards
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ check-pass
2+
//@ revisions: current next
3+
//@[next] compile-flags: -Znext-solver
4+
5+
trait Trait
6+
where
7+
Self::Assoc: Clone,
8+
{
9+
type Assoc;
10+
}
11+
12+
fn foo<T: Trait>(x: &T::Assoc) -> T::Assoc {
13+
x.clone()
14+
}
15+
16+
trait Trait2
17+
where
18+
Self::Assoc: Iterator,
19+
<Self::Assoc as Iterator>::Item: Clone,
20+
{
21+
type Assoc;
22+
}
23+
24+
fn foo2<T: Trait2>(x: &<T::Assoc as Iterator>::Item) -> <T::Assoc as Iterator>::Item {
25+
x.clone()
26+
}
27+
28+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ check-pass
2+
3+
// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`,
4+
// just as it would be if it weren't a GAT but just a regular associated type.
5+
6+
use std::fmt::Debug;
7+
8+
trait Foo
9+
where
10+
for<'a> Self::Gat<'a>: Debug,
11+
{
12+
type Gat<'a>;
13+
}
14+
15+
fn test<T: Foo>(x: T::Gat<'static>) {
16+
println!("{:?}", x);
17+
}
18+
19+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//@ check-pass
2+
//@ revisions: current next
3+
//@[next] compile-flags: -Znext-solver
4+
5+
trait Foo
6+
where
7+
Self::Iterator: Iterator,
8+
<Self::Iterator as Iterator>::Item: Bar,
9+
{
10+
type Iterator;
11+
12+
fn iter() -> Self::Iterator;
13+
}
14+
15+
trait Bar {
16+
fn bar(&self);
17+
}
18+
19+
fn x<T: Foo>() {
20+
T::iter().next().unwrap().bar();
21+
}
22+
23+
fn main() {}

0 commit comments

Comments
 (0)