Skip to content

Commit 11c141a

Browse files
committed
abi layout: WIP
1 parent 257377c commit 11c141a

File tree

3 files changed

+146
-79
lines changed

3 files changed

+146
-79
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 141 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
55
use crate::codegen_cx::CodegenCx;
66
use crate::spirv_type::SpirvType;
7+
use crate::symbols::Symbols;
78
use itertools::Itertools;
89
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
910
use rustc_abi::{AbiAlign, ExternAbi as Abi};
@@ -16,10 +17,10 @@ use rustc_errors::ErrorGuaranteed;
1617
use rustc_hashes::Hash64;
1718
use rustc_index::Idx;
1819
use rustc_middle::query::Providers;
19-
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
20+
use rustc_middle::ty::layout::{FnAbiOf, LayoutError, LayoutOf, TyAndLayout};
2021
use rustc_middle::ty::{
21-
self, AdtDef, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, IntTy, PolyFnSig, Ty,
22-
TyCtxt, TyKind, UintTy,
22+
self, AdtDef, Const, CoroutineArgs, CoroutineArgsExt as _, FloatTy, GenericArgs, IntTy,
23+
PolyFnSig, Ty, TyCtxt, TyKind, TypingEnv, UintTy,
2324
};
2425
use rustc_middle::ty::{GenericArgsRef, ScalarInt};
2526
use rustc_middle::{bug, span_bug};
@@ -169,85 +170,14 @@ pub(crate) fn provide(providers: &mut Providers) {
169170
fn layout_of<'tcx>(
170171
tcx: TyCtxt<'tcx>,
171172
key: ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>,
172-
) -> Result<TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
173+
) -> Result<TyAndLayout<'tcx>, &'tcx LayoutError<'tcx>> {
173174
// HACK(eddyb) to special-case any types at all, they must be normalized,
174175
// but when normalization would be needed, `layout_of`'s default provider
175176
// recurses (supposedly for caching reasons), i.e. its calls `layout_of`
176177
// w/ the normalized type in input, which once again reaches this hook,
177178
// without ever needing any explicit normalization here.
178-
let ty = key.value;
179-
180-
// HACK(eddyb) bypassing upstream `#[repr(simd)]` changes (see also
181-
// the later comment above `check_well_formed`, for more details).
182-
let reimplement_old_style_repr_simd: Option<(&AdtDef<'tcx>, Ty<'tcx>, u64)> = match ty
183-
.kind()
184-
{
185-
ty::Adt(def, args) if def.repr().simd() && !def.repr().packed() && def.is_struct() => {
186-
Some(def.non_enum_variant()).and_then(|v| {
187-
let (count, e_ty) = v
188-
.fields
189-
.iter()
190-
.map(|f| f.ty(tcx, args))
191-
.dedup_with_count()
192-
.exactly_one()
193-
.ok()?;
194-
let e_len = u64::try_from(count).ok().filter(|&e_len| e_len > 1)?;
195-
Some((def, e_ty, e_len))
196-
})
197-
}
198-
_ => None,
199-
};
200-
201-
// HACK(eddyb) tweaked copy of the old upstream logic for `#[repr(simd)]`:
202-
// https://github.com/rust-lang/rust/blob/1.86.0/compiler/rustc_ty_utils/src/layout.rs#L464-L590
203-
if let Some((adt_def, e_ty, e_len)) = reimplement_old_style_repr_simd {
204-
let cx = rustc_middle::ty::layout::LayoutCx::new(
205-
tcx,
206-
key.typing_env.with_post_analysis_normalized(tcx),
207-
);
208-
209-
// Compute the ABI of the element type:
210-
let e_ly: TyAndLayout<'_> = cx.layout_of(e_ty)?;
211-
let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
212-
// This error isn't caught in typeck, e.g., if
213-
// the element type of the vector is generic.
214-
tcx.dcx().span_fatal(
215-
tcx.def_span(adt_def.did()),
216-
format!(
217-
"SIMD type `{ty}` with a non-primitive-scalar \
218-
(integer/float/pointer) element type `{}`",
219-
e_ly.ty
220-
),
221-
);
222-
};
223-
224-
// Compute the size and alignment of the vector:
225-
let size = e_ly.size.checked_mul(e_len, &cx).unwrap();
226-
let align = adt_def.repr().align.unwrap_or(e_ly.align.abi);
227-
let size = size.align_to(align);
228-
229-
let layout = tcx.mk_layout(LayoutData {
230-
variants: Variants::Single {
231-
index: rustc_abi::FIRST_VARIANT,
232-
},
233-
fields: FieldsShape::Array {
234-
stride: e_ly.size,
235-
count: e_len,
236-
},
237-
backend_repr: BackendRepr::SimdVector {
238-
element: e_repr,
239-
count: e_len,
240-
},
241-
largest_niche: e_ly.largest_niche,
242-
uninhabited: false,
243-
size,
244-
align: AbiAlign::new(align),
245-
max_repr_align: None,
246-
unadjusted_abi_align: align,
247-
randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
248-
});
249-
250-
return Ok(TyAndLayout { ty, layout });
179+
if let Some(layout) = layout_of_spirv_attr_special(tcx, key)? {
180+
return Ok(layout);
251181
}
252182

253183
let TyAndLayout { ty, mut layout } =
@@ -276,6 +206,135 @@ pub(crate) fn provide(providers: &mut Providers) {
276206
Ok(TyAndLayout { ty, layout })
277207
}
278208

209+
fn layout_of_spirv_attr_special<'tcx>(
210+
tcx: TyCtxt<'tcx>,
211+
key: ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>,
212+
) -> Result<Option<TyAndLayout<'tcx>>, &'tcx LayoutError<'tcx>> {
213+
let ty::PseudoCanonicalInput {
214+
typing_env,
215+
value: ty,
216+
} = key;
217+
218+
match ty.kind() {
219+
ty::Adt(def, args) => {
220+
let def: &AdtDef<'tcx> = def;
221+
let args: &'tcx GenericArgs<'tcx> = args;
222+
let attrs = AggregatedSpirvAttributes::parse(
223+
tcx,
224+
&Symbols::get(),
225+
tcx.get_all_attrs(def.did()),
226+
);
227+
228+
// add spirv-attr special layouts here
229+
if let Some(layout) =
230+
layout_of_spirv_vector(tcx, typing_env, ty, def, args, &attrs)?
231+
{
232+
return Ok(Some(layout));
233+
}
234+
}
235+
_ => {}
236+
}
237+
Ok(None)
238+
}
239+
240+
fn layout_of_spirv_vector<'tcx>(
241+
tcx: TyCtxt<'tcx>,
242+
typing_env: TypingEnv<'tcx>,
243+
ty: Ty<'tcx>,
244+
def: &AdtDef<'tcx>,
245+
args: &'tcx GenericArgs<'tcx>,
246+
attrs: &AggregatedSpirvAttributes,
247+
) -> Result<Option<TyAndLayout<'tcx>>, &'tcx LayoutError<'tcx>> {
248+
let layout_err = |msg| {
249+
&*tcx.arena.alloc(LayoutError::ReferencesError(
250+
tcx.dcx().span_err(tcx.def_span(def.did()), msg),
251+
))
252+
};
253+
254+
let has_spirv_vector_attr = attrs
255+
.intrinsic_type
256+
.as_ref()
257+
.map_or(false, |attr| matches!(attr.value, IntrinsicType::Vector));
258+
let has_repr_simd = def.repr().simd() && !def.repr().packed();
259+
if !has_spirv_vector_attr && !has_repr_simd {
260+
return Ok(None);
261+
}
262+
263+
let elements = def
264+
.non_enum_variant()
265+
.fields
266+
.iter()
267+
.map(|f| f.ty(tcx, args))
268+
.dedup_with_count()
269+
.exactly_one()
270+
.ok()
271+
.and_then(|(count, e_ty)| {
272+
u64::try_from(count)
273+
.ok()
274+
.filter(|&e_len| e_len >= 2)
275+
.map(|e_len| (e_len, e_ty))
276+
});
277+
let (e_len, e_ty) = match elements {
278+
None => {
279+
return if has_repr_simd {
280+
// core SIMD struct, not glam vector, don't do anything special
281+
Ok(None)
282+
} else {
283+
Err(layout_err(format!(
284+
"spirv vector type `{ty}` must have at least 2 elements of a single element"
285+
)))
286+
};
287+
}
288+
Some(len) => len,
289+
};
290+
if !def.is_struct() {
291+
return Err(layout_err(format!(
292+
"spirv vector type `{ty}` must be a struct"
293+
)));
294+
}
295+
296+
let lcx = ty::layout::LayoutCx::new(tcx, typing_env.with_post_analysis_normalized(tcx));
297+
298+
// Compute the ABI of the element type:
299+
let e_ly: TyAndLayout<'_> = lcx.layout_of(e_ty)?;
300+
let BackendRepr::Scalar(e_repr) = e_ly.backend_repr else {
301+
// This error isn't caught in typeck, e.g., if
302+
// the element type of the vector is generic.
303+
return Err(layout_err(format!(
304+
"spirv vector type `{ty}` must have a non-primitive-scalar (integer/float/pointer) element type, got `{}`",
305+
e_ly.ty
306+
)));
307+
};
308+
309+
// Compute the size and alignment of the vector:
310+
let size = e_ly.size.checked_mul(e_len, &lcx).unwrap();
311+
let align = def.repr().align.unwrap_or(e_ly.align.abi);
312+
let size = size.align_to(align);
313+
314+
let layout = tcx.mk_layout(LayoutData {
315+
variants: Variants::Single {
316+
index: rustc_abi::FIRST_VARIANT,
317+
},
318+
fields: FieldsShape::Array {
319+
stride: e_ly.size,
320+
count: e_len,
321+
},
322+
backend_repr: BackendRepr::SimdVector {
323+
element: e_repr,
324+
count: e_len,
325+
},
326+
largest_niche: e_ly.largest_niche,
327+
uninhabited: false,
328+
size,
329+
align: AbiAlign::new(align),
330+
max_repr_align: None,
331+
unadjusted_abi_align: align,
332+
randomization_seed: e_ly.randomization_seed.wrapping_add(Hash64::new(e_len)),
333+
});
334+
335+
Ok(Some(TyAndLayout { ty, layout }))
336+
}
337+
279338
// HACK(eddyb) work around https://github.com/rust-lang/rust/pull/129403
280339
// banning "struct-style" `#[repr(simd)]` (in favor of "array-newtype-style"),
281340
// by simply bypassing "type definition WF checks" for affected types, which:
@@ -778,9 +837,9 @@ fn dig_scalar_pointee<'tcx>(
778837
match pointee {
779838
Some(old_pointee) if old_pointee != new_pointee => {
780839
cx.tcx.dcx().fatal(format!(
781-
"dig_scalar_pointee: unsupported Pointer with different \
840+
"dig_scalar_pointee: unsupported Pointer with different \
782841
pointee types ({old_pointee:?} vs {new_pointee:?}) at offset {offset:?} in {layout:#?}"
783-
));
842+
));
784843
}
785844
_ => pointee = Some(new_pointee),
786845
}
@@ -1265,5 +1324,8 @@ fn trans_intrinsic_type<'tcx>(
12651324
}
12661325
.def(span, cx))
12671326
}
1327+
IntrinsicType::Vector => {
1328+
todo!()
1329+
}
12681330
}
12691331
}

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub enum IntrinsicType {
6565
RuntimeArray,
6666
TypedBuffer,
6767
Matrix,
68+
Vector,
6869
}
6970

7071
#[derive(Copy, Clone, Debug, PartialEq, Eq)]

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ impl Symbols {
373373
"matrix",
374374
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
375375
),
376+
(
377+
"vector",
378+
SpirvAttribute::IntrinsicType(IntrinsicType::Vector),
379+
),
376380
("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
377381
(
378382
"buffer_store_intrinsic",

0 commit comments

Comments
 (0)