Skip to content

Commit d0db6bc

Browse files
committed
Add or-patterns to pattern types
1 parent e578e31 commit d0db6bc

File tree

35 files changed

+477
-10
lines changed

35 files changed

+477
-10
lines changed

compiler/rustc_ast/src/ast.rs

+2
Original file line numberDiff line numberDiff line change
@@ -2364,6 +2364,8 @@ pub enum TyPatKind {
23642364
/// A range pattern (e.g., `1...2`, `1..2`, `1..`, `..2`, `1..=2`, `..=2`).
23652365
Range(Option<P<AnonConst>>, Option<P<AnonConst>>, Spanned<RangeEnd>),
23662366

2367+
Or(ThinVec<P<TyPat>>),
2368+
23672369
/// A `!null` pattern for raw pointers.
23682370
NotNull,
23692371

compiler/rustc_ast/src/mut_visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ pub fn walk_ty_pat<T: MutVisitor>(vis: &mut T, ty: &mut P<TyPat>) {
609609
visit_opt(start, |c| vis.visit_anon_const(c));
610610
visit_opt(end, |c| vis.visit_anon_const(c));
611611
}
612+
TyPatKind::Or(variants) => visit_thin_vec(variants, |p| vis.visit_ty_pat(p)),
612613
TyPatKind::NotNull | TyPatKind::Err(_) => {}
613614
}
614615
visit_lazy_tts(vis, tokens);

compiler/rustc_ast/src/visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ pub fn walk_ty_pat<'a, V: Visitor<'a>>(visitor: &mut V, tp: &'a TyPat) -> V::Res
565565
visit_opt!(visitor, visit_anon_const, start);
566566
visit_opt!(visitor, visit_anon_const, end);
567567
}
568+
TyPatKind::Or(variants) => walk_list!(visitor, visit_ty_pat, variants),
568569
TyPatKind::NotNull | TyPatKind::Err(_) => {}
569570
}
570571
V::Result::output()

compiler/rustc_ast_lowering/src/pat.rs

+5
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
465465
)
466466
}),
467467
),
468+
TyPatKind::Or(variants) => {
469+
hir::TyPatKind::Or(self.arena.alloc_from_iter(
470+
variants.iter().map(|pat| self.lower_ty_pat_mut(pat, base_type)),
471+
))
472+
}
468473
TyPatKind::NotNull => hir::TyPatKind::NotNull,
469474
TyPatKind::Err(guar) => hir::TyPatKind::Err(*guar),
470475
};

compiler/rustc_ast_pretty/src/pprust/state.rs

+11
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,17 @@ impl<'a> State<'a> {
11671167
self.print_expr_anon_const(end, &[]);
11681168
}
11691169
}
1170+
rustc_ast::TyPatKind::Or(variants) => {
1171+
let mut first = true;
1172+
for pat in variants {
1173+
if first {
1174+
first = false
1175+
} else {
1176+
self.word(" | ");
1177+
}
1178+
self.print_ty_pat(pat);
1179+
}
1180+
}
11701181
rustc_ast::TyPatKind::NotNull => self.word("!null"),
11711182
rustc_ast::TyPatKind::Err(_) => {
11721183
self.popen();

compiler/rustc_builtin_macros/src/pattern_type.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token};
44
use rustc_errors::PResult;
55
use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult};
66
use rustc_parse::exp;
7+
use rustc_parse::parser::{CommaRecoveryMode, RecoverColon, RecoverComma};
78
use rustc_span::Span;
89

910
pub(crate) fn expand<'cx>(
@@ -33,7 +34,17 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P
3334
let span = start.to(parser.token.span);
3435
ty_pat(TyPatKind::NotNull, span)
3536
} else {
36-
pat_to_ty_pat(cx, parser.parse_pat_no_top_alt(None, None)?.into_inner())
37+
pat_to_ty_pat(
38+
cx,
39+
parser
40+
.parse_pat_no_top_guard(
41+
None,
42+
RecoverComma::No,
43+
RecoverColon::No,
44+
CommaRecoveryMode::EitherTupleOrPipe,
45+
)?
46+
.into_inner(),
47+
)
3748
};
3849

3950
if parser.token != token::Eof {
@@ -53,6 +64,9 @@ fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> P<TyPat> {
5364
end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })),
5465
include_end,
5566
),
67+
ast::PatKind::Or(variants) => TyPatKind::Or(
68+
variants.into_iter().map(|pat| pat_to_ty_pat(cx, pat.into_inner())).collect(),
69+
),
5670
ast::PatKind::Err(guar) => TyPatKind::Err(guar),
5771
_ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")),
5872
};

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+3
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>(
7070
ty::Pat(_, pat) => match **pat {
7171
ty::PatternKind::Range { .. } => ConstValue::from_target_usize(0u64, &tcx),
7272
ty::PatternKind::NotNull => ConstValue::from_target_usize(0_u64, &tcx),
73+
// FIXME(pattern_types): make this report the number of distinct variants used in the
74+
// or pattern in case the base type is an enum.
75+
ty::PatternKind::Or(_) => ConstValue::from_target_usize(0_u64, &tcx),
7376
},
7477
ty::Bound(_, _) => bug!("bound ty during ctfe"),
7578
ty::Bool

compiler/rustc_const_eval/src/interpret/validity.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,10 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
12491249
// handled fully by `visit_scalar` (called below).
12501250
ty::PatternKind::Range { .. } => {},
12511251
ty::PatternKind::NotNull => {},
1252+
1253+
// FIXME(pattern_types): check that the value is covered by one of the variants.
1254+
// The layout may pessimistically cover actually illegal ranges.
1255+
ty::PatternKind::Or(_patterns) => {}
12521256
}
12531257
}
12541258
_ => {

compiler/rustc_hir/src/hir.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1670,6 +1670,9 @@ pub enum TyPatKind<'hir> {
16701670
/// A range pattern (e.g., `1..=2` or `1..2`).
16711671
Range(&'hir ConstArg<'hir>, &'hir ConstArg<'hir>),
16721672

1673+
/// A list of patterns where only one needs to be satisfied
1674+
Or(&'hir [TyPat<'hir>]),
1675+
16731676
/// A pattern that excludes null pointers
16741677
NotNull,
16751678

compiler/rustc_hir/src/intravisit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ pub fn walk_ty_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v TyPat<'v>)
693693
try_visit!(visitor.visit_const_arg_unambig(lower_bound));
694694
try_visit!(visitor.visit_const_arg_unambig(upper_bound));
695695
}
696+
TyPatKind::Or(patterns) => walk_list!(visitor, visit_pattern_type_pattern, patterns),
696697
TyPatKind::NotNull | TyPatKind::Err(_) => (),
697698
}
698699
V::Result::output()

compiler/rustc_hir_analysis/src/collect/type_of.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,12 @@ fn const_arg_anon_type_of<'tcx>(icx: &ItemCtxt<'tcx>, arg_hir_id: HirId, span: S
9393
}
9494

9595
Node::TyPat(pat) => {
96-
let hir::TyKind::Pat(ty, p) = tcx.parent_hir_node(pat.hir_id).expect_ty().kind else {
97-
bug!()
96+
let node = match tcx.parent_hir_node(pat.hir_id) {
97+
// Or patterns can be nested one level deep
98+
Node::TyPat(p) => tcx.parent_hir_node(p.hir_id),
99+
other => other,
98100
};
99-
assert_eq!(p.hir_id, pat.hir_id);
101+
let hir::TyKind::Pat(ty, _) = node.expect_ty().kind else { bug!() };
100102
icx.lower_ty(ty)
101103
}
102104

compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -2709,6 +2709,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
27092709
ty_span: Span,
27102710
pat: &hir::TyPat<'tcx>,
27112711
) -> Result<ty::PatternKind<'tcx>, ErrorGuaranteed> {
2712+
let tcx = self.tcx();
27122713
match pat.kind {
27132714
hir::TyPatKind::Range(start, end) => {
27142715
match ty.kind() {
@@ -2724,6 +2725,13 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
27242725
.span_delayed_bug(ty_span, "invalid base type for range pattern")),
27252726
}
27262727
}
2728+
hir::TyPatKind::Or(patterns) => {
2729+
self.tcx()
2730+
.mk_patterns_from_iter(patterns.iter().map(|pat| {
2731+
self.lower_pat_ty_pat(ty, ty_span, pat).map(|pat| tcx.mk_pat(pat))
2732+
}))
2733+
.map(ty::PatternKind::Or)
2734+
}
27272735
hir::TyPatKind::NotNull => Ok(ty::PatternKind::NotNull),
27282736
hir::TyPatKind::Err(e) => Err(e),
27292737
}

compiler/rustc_hir_analysis/src/variance/constraints.rs

+5
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> {
340340
self.add_constraints_from_const(current, start, variance);
341341
self.add_constraints_from_const(current, end, variance);
342342
}
343+
ty::PatternKind::Or(patterns) => {
344+
for pat in patterns {
345+
self.add_constraints_from_pat(current, variance, pat)
346+
}
347+
}
343348
ty::PatternKind::NotNull => {}
344349
}
345350
}

compiler/rustc_hir_pretty/src/lib.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1877,6 +1877,19 @@ impl<'a> State<'a> {
18771877
self.word("..=");
18781878
self.print_const_arg(end);
18791879
}
1880+
TyPatKind::Or(patterns) => {
1881+
self.popen();
1882+
let mut first = true;
1883+
for pat in patterns {
1884+
if first {
1885+
first = false;
1886+
} else {
1887+
self.word(" | ");
1888+
}
1889+
self.print_ty_pat(pat);
1890+
}
1891+
self.pclose();
1892+
}
18801893
TyPatKind::NotNull => {
18811894
self.word_space("not");
18821895
self.word("null");

compiler/rustc_lint/src/types.rs

+10
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,9 @@ fn pat_ty_is_known_nonnull<'tcx>(
901901
// to ensure we aren't wrapping over zero.
902902
start > 0 && end >= start
903903
}
904+
ty::PatternKind::Or(patterns) => {
905+
patterns.iter().all(|pat| pat_ty_is_known_nonnull(tcx, typing_env, pat))
906+
}
904907
ty::PatternKind::NotNull => true,
905908
}
906909
},
@@ -1063,6 +1066,13 @@ fn get_nullable_type_from_pat<'tcx>(
10631066
ty::PatternKind::NotNull | ty::PatternKind::Range { .. } => {
10641067
get_nullable_type(tcx, typing_env, base)
10651068
}
1069+
ty::PatternKind::Or(patterns) => {
1070+
let first = get_nullable_type_from_pat(tcx, typing_env, base, patterns[0])?;
1071+
for &pat in &patterns[1..] {
1072+
assert_eq!(first, get_nullable_type_from_pat(tcx, typing_env, base, pat)?);
1073+
}
1074+
Some(first)
1075+
}
10661076
}
10671077
}
10681078

compiler/rustc_middle/src/ty/codec.rs

+10
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,15 @@ impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D>
419419
}
420420
}
421421

422+
impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for ty::List<ty::Pattern<'tcx>> {
423+
fn decode(decoder: &mut D) -> &'tcx Self {
424+
let len = decoder.read_usize();
425+
decoder.interner().mk_patterns_from_iter(
426+
(0..len).map::<ty::Pattern<'tcx>, _>(|_| Decodable::decode(decoder)),
427+
)
428+
}
429+
}
430+
422431
impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for ty::List<ty::Const<'tcx>> {
423432
fn decode(decoder: &mut D) -> &'tcx Self {
424433
let len = decoder.read_usize();
@@ -482,6 +491,7 @@ impl_decodable_via_ref! {
482491
&'tcx mir::Body<'tcx>,
483492
&'tcx mir::BorrowCheckResult<'tcx>,
484493
&'tcx ty::List<ty::BoundVariableKind>,
494+
&'tcx ty::List<ty::Pattern<'tcx>>,
485495
&'tcx ty::ListWithCachedTypeInfo<ty::Clause<'tcx>>,
486496
&'tcx ty::List<FieldIdx>,
487497
&'tcx ty::List<(VariantIdx, FieldIdx)>,

compiler/rustc_middle/src/ty/context.rs

+11
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ pub struct CtxtInterners<'tcx> {
812812
captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>,
813813
offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>,
814814
valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>,
815+
patterns: InternedSet<'tcx, List<ty::Pattern<'tcx>>>,
815816
}
816817

817818
impl<'tcx> CtxtInterners<'tcx> {
@@ -848,6 +849,7 @@ impl<'tcx> CtxtInterners<'tcx> {
848849
captures: InternedSet::with_capacity(N),
849850
offset_of: InternedSet::with_capacity(N),
850851
valtree: InternedSet::with_capacity(N),
852+
patterns: InternedSet::with_capacity(N),
851853
}
852854
}
853855

@@ -2594,6 +2596,7 @@ slice_interners!(
25942596
local_def_ids: intern_local_def_ids(LocalDefId),
25952597
captures: intern_captures(&'tcx ty::CapturedPlace<'tcx>),
25962598
offset_of: pub mk_offset_of((VariantIdx, FieldIdx)),
2599+
patterns: pub mk_patterns(Pattern<'tcx>),
25972600
);
25982601

25992602
impl<'tcx> TyCtxt<'tcx> {
@@ -2867,6 +2870,14 @@ impl<'tcx> TyCtxt<'tcx> {
28672870
self.intern_local_def_ids(clauses)
28682871
}
28692872

2873+
pub fn mk_patterns_from_iter<I, T>(self, iter: I) -> T::Output
2874+
where
2875+
I: Iterator<Item = T>,
2876+
T: CollectAndApply<ty::Pattern<'tcx>, &'tcx List<ty::Pattern<'tcx>>>,
2877+
{
2878+
T::collect_and_apply(iter, |xs| self.mk_patterns(xs))
2879+
}
2880+
28702881
pub fn mk_local_def_ids_from_iter<I, T>(self, iter: I) -> T::Output
28712882
where
28722883
I: Iterator<Item = T>,

compiler/rustc_middle/src/ty/flags.rs

+5
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ impl FlagComputation {
259259
self.add_const(start);
260260
self.add_const(end);
261261
}
262+
ty::PatternKind::Or(patterns) => {
263+
for pat in patterns {
264+
self.add_pat(pat);
265+
}
266+
}
262267
ty::PatternKind::NotNull => {}
263268
}
264269
}

compiler/rustc_middle/src/ty/pattern.rs

+14
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ impl<'tcx> fmt::Debug for PatternKind<'tcx> {
5151

5252
write!(f, "..={end}")
5353
}
54+
PatternKind::Or(patterns) => {
55+
write!(f, "(")?;
56+
let mut first = true;
57+
for pat in patterns {
58+
if first {
59+
first = false
60+
} else {
61+
write!(f, " | ")?;
62+
}
63+
write!(f, "{pat:?}")?;
64+
}
65+
write!(f, ")")
66+
}
5467
PatternKind::NotNull => write!(f, "!null"),
5568
}
5669
}
@@ -60,5 +73,6 @@ impl<'tcx> fmt::Debug for PatternKind<'tcx> {
6073
#[derive(HashStable, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
6174
pub enum PatternKind<'tcx> {
6275
Range { start: ty::Const<'tcx>, end: ty::Const<'tcx> },
76+
Or(&'tcx ty::List<Pattern<'tcx>>),
6377
NotNull,
6478
}

compiler/rustc_middle/src/ty/relate.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,19 @@ impl<'tcx> Relate<TyCtxt<'tcx>> for ty::Pattern<'tcx> {
5959
let end = relation.relate(end_a, end_b)?;
6060
Ok(tcx.mk_pat(ty::PatternKind::Range { start, end }))
6161
}
62-
(ty::PatternKind::NotNull, ty::PatternKind::NotNull) => Ok(a),
63-
(ty::PatternKind::NotNull | ty::PatternKind::Range { .. }, _) => {
64-
Err(TypeError::Mismatch)
62+
(&ty::PatternKind::Or(a), &ty::PatternKind::Or(b)) => {
63+
if a.len() != b.len() {
64+
return Err(TypeError::Mismatch);
65+
}
66+
let v = iter::zip(a, b).map(|(a, b)| relation.relate(a, b));
67+
let patterns = tcx.mk_patterns_from_iter(v)?;
68+
Ok(tcx.mk_pat(ty::PatternKind::Or(patterns)))
6569
}
70+
(ty::PatternKind::NotNull, ty::PatternKind::NotNull) => Ok(a),
71+
(
72+
ty::PatternKind::NotNull | ty::PatternKind::Range { .. } | ty::PatternKind::Or(_),
73+
_,
74+
) => Err(TypeError::Mismatch),
6675
}
6776
}
6877
}

compiler/rustc_middle/src/ty/structural_impls.rs

+9
Original file line numberDiff line numberDiff line change
@@ -720,3 +720,12 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for &'tcx ty::List<PlaceElem<'tcx>> {
720720
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_place_elems(v))
721721
}
722722
}
723+
724+
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for &'tcx ty::List<ty::Pattern<'tcx>> {
725+
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
726+
self,
727+
folder: &mut F,
728+
) -> Result<Self, F::Error> {
729+
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_patterns(v))
730+
}
731+
}

compiler/rustc_middle/src/ty/walk.rs

+5
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ fn push_pat<'tcx>(stack: &mut SmallVec<[GenericArg<'tcx>; 8]>, pat: ty::Pattern<
217217
stack.push(end.into());
218218
stack.push(start.into());
219219
}
220+
ty::PatternKind::Or(patterns) => {
221+
for pat in patterns {
222+
push_pat(stack, pat)
223+
}
224+
}
220225
ty::PatternKind::NotNull => {}
221226
}
222227
}

compiler/rustc_resolve/src/late.rs

+5
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,11 @@ impl<'ra: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'_, 'ast, 'r
937937
self.resolve_anon_const(end, AnonConstKind::ConstArg(IsRepeatExpr::No));
938938
}
939939
}
940+
TyPatKind::Or(patterns) => {
941+
for pat in patterns {
942+
self.visit_ty_pat(pat)
943+
}
944+
}
940945
TyPatKind::NotNull | TyPatKind::Err(_) => {}
941946
}
942947
}

compiler/rustc_smir/src/rustc_smir/convert/ty.rs

+1
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ impl<'tcx> Stable<'tcx> for ty::Pattern<'tcx> {
411411
end: Some(end.stable(tables)),
412412
include_end: true,
413413
},
414+
ty::PatternKind::Or(_) => todo!(),
414415
ty::PatternKind::NotNull => stable_mir::ty::Pattern::NotNull,
415416
}
416417
}

0 commit comments

Comments
 (0)