Skip to content

Commit 1ae2d38

Browse files
committed
fixed fields chain, added new test cases
1 parent cde9fa0 commit 1ae2d38

File tree

3 files changed

+82
-30
lines changed

3 files changed

+82
-30
lines changed

compiler/rustc_mir_build/src/check_unsafety.rs

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ struct UnsafetyVisitor<'a, 'tcx> {
4545
/// Flag to ensure that we only suggest wrapping the entire function body in
4646
/// an unsafe block once.
4747
suggest_unsafe_block: bool,
48+
/// Controls how union field accesses are checked
49+
union_field_access_mode: UnionFieldAccessMode,
4850
}
4951

5052
impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
@@ -218,6 +220,7 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
218220
inside_adt: false,
219221
warnings: self.warnings,
220222
suggest_unsafe_block: self.suggest_unsafe_block,
223+
union_field_access_mode: UnionFieldAccessMode::Normal,
221224
};
222225
// params in THIR may be unsafe, e.g. a union pattern.
223226
for param in &inner_thir.params {
@@ -658,18 +661,25 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
658661
} else if adt_def.is_union() {
659662
// Check if this field access is part of a raw borrow operation
660663
// If so, we've already handled it above and shouldn't reach here
661-
if let Some(assigned_ty) = self.assignment_info {
662-
if assigned_ty.needs_drop(self.tcx, self.typing_env) {
663-
// This would be unsafe, but should be outright impossible since we
664-
// reject such unions.
665-
assert!(
666-
self.tcx.dcx().has_errors().is_some(),
667-
"union fields that need dropping should be impossible: {assigned_ty}"
668-
);
664+
match self.union_field_access_mode {
665+
UnionFieldAccessMode::SuppressUnionFieldAccessError => {
666+
// Suppress AccessToUnionField error for union fields chains
667+
}
668+
UnionFieldAccessMode::Normal => {
669+
if let Some(assigned_ty) = self.assignment_info {
670+
if assigned_ty.needs_drop(self.tcx, self.typing_env) {
671+
// This would be unsafe, but should be outright impossible since we
672+
// reject such unions.
673+
assert!(
674+
self.tcx.dcx().has_errors().is_some(),
675+
"union fields that need dropping should be impossible: {assigned_ty}"
676+
);
677+
}
678+
} else {
679+
// Only require unsafe if this is not a raw borrow operation
680+
self.requires_unsafe(expr.span, AccessToUnionField);
681+
}
669682
}
670-
} else {
671-
// Only require unsafe if this is not a raw borrow operation
672-
self.requires_unsafe(expr.span, AccessToUnionField);
673683
}
674684
}
675685
}
@@ -728,7 +738,7 @@ impl<'a, 'tcx> UnsafetyVisitor<'a, 'tcx> {
728738
match self.thir[expr_id].kind {
729739
ExprKind::Field { lhs, .. } => {
730740
let lhs = &self.thir[lhs];
731-
if let ty::Adt(adt_def, _) = lhs.ty.kind() { adt_def.is_union() } else { false }
741+
matches!(lhs.ty.kind(), ty::Adt(adt_def, _) if adt_def.is_union())
732742
}
733743
_ => false,
734744
}
@@ -737,28 +747,28 @@ impl<'a, 'tcx> UnsafetyVisitor<'a, 'tcx> {
737747
/// Visit a union field access in the context of a raw borrow operation
738748
/// This ensures we still check safety of nested operations while allowing
739749
/// the raw pointer creation itself
740-
fn visit_union_field_for_raw_borrow(&mut self, expr_id: ExprId) {
741-
match self.thir[expr_id].kind {
742-
ExprKind::Field { lhs, variant_index, name } => {
743-
let lhs_expr = &self.thir[lhs];
744-
if let ty::Adt(adt_def, _) = lhs_expr.ty.kind() {
745-
// Check for unsafe fields but skip the union access check
746-
if adt_def.variant(variant_index).fields[name].safety.is_unsafe() {
747-
self.requires_unsafe(self.thir[expr_id].span, UseOfUnsafeField);
748-
}
749-
// For unions, we don't require unsafe for raw pointer creation
750-
// But we still need to check the LHS for safety
751-
self.visit_expr(lhs_expr);
752-
} else {
753-
// Not a union, use normal visiting
754-
visit::walk_expr(self, &self.thir[expr_id]);
750+
fn visit_union_field_for_raw_borrow(&mut self, mut expr_id: ExprId) {
751+
let prev = self.union_field_access_mode;
752+
self.union_field_access_mode = UnionFieldAccessMode::SuppressUnionFieldAccessError;
753+
// Walk through the chain of union field accesses using while let
754+
while let ExprKind::Field { lhs, variant_index, name } = self.thir[expr_id].kind {
755+
let lhs_expr = &self.thir[lhs];
756+
if let ty::Adt(adt_def, _) = lhs_expr.ty.kind() {
757+
// Check for unsafe fields but skip the union access check
758+
if adt_def.variant(variant_index).fields[name].safety.is_unsafe() {
759+
self.requires_unsafe(self.thir[expr_id].span, UseOfUnsafeField);
755760
}
756-
}
757-
_ => {
758-
// Not a field access, use normal visiting
761+
// If the LHS is also a union field access, keep walking
762+
expr_id = lhs;
763+
} else {
764+
// Not a union, use normal visiting
759765
visit::walk_expr(self, &self.thir[expr_id]);
766+
return;
760767
}
761768
}
769+
// Visit the base expression for any nested safety checks
770+
self.visit_expr(&self.thir[expr_id]);
771+
self.union_field_access_mode = prev;
762772
}
763773
}
764774

@@ -770,6 +780,13 @@ enum SafetyContext {
770780
UnsafeBlock { span: Span, hir_id: HirId, used: bool, nested_used_blocks: Vec<NestedUsedBlock> },
771781
}
772782

783+
/// Controls how union field accesses are checked
784+
#[derive(Clone, Copy)]
785+
enum UnionFieldAccessMode {
786+
Normal,
787+
SuppressUnionFieldAccessError,
788+
}
789+
773790
#[derive(Clone, Copy)]
774791
struct NestedUsedBlock {
775792
hir_id: HirId,
@@ -1244,6 +1261,7 @@ pub(crate) fn check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
12441261
inside_adt: false,
12451262
warnings: &mut warnings,
12461263
suggest_unsafe_block: true,
1264+
union_field_access_mode: UnionFieldAccessMode::Normal,
12471265
};
12481266
// params in THIR may be unsafe, e.g. a union pattern.
12491267
for param in &thir.params {

main

451 KB
Binary file not shown.

tests/ui/union/union-unsafe.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,38 @@ fn main() {
9696
let mut u3 = U3 { a: ManuallyDrop::new(String::from("old")) }; // OK
9797
u3.a = ManuallyDrop::new(String::from("new")); // OK (assignment does not drop)
9898
*u3.a = String::from("new"); //~ ERROR access to union field is unsafe
99+
100+
let mut unions = [U1 { a: 1 }, U1 { a: 2 }];
101+
102+
// Array indexing + union field raw borrow - should be OK
103+
let ptr = &raw mut unions[0].a; // OK
104+
let ptr2 = &raw const unions[1].a; // OK
105+
106+
// Test for union fields chain, this should be allowed
107+
#[derive(Copy, Clone)]
108+
union Inner {
109+
a: u8,
110+
}
111+
112+
union MoreInner {
113+
moreinner: ManuallyDrop<Inner>,
114+
}
115+
116+
union LessOuter {
117+
lessouter: ManuallyDrop<MoreInner>,
118+
}
119+
120+
union Outer {
121+
outer: ManuallyDrop<LessOuter>,
122+
}
123+
124+
let super_outer = Outer {
125+
outer: ManuallyDrop::new(LessOuter {
126+
lessouter: ManuallyDrop::new(MoreInner {
127+
moreinner: ManuallyDrop::new(Inner { a: 42 }),
128+
}),
129+
}),
130+
};
131+
132+
let ptr = &raw const super_outer.outer.lessouter.moreinner.a;
99133
}

0 commit comments

Comments
 (0)