|
3 | 3 |
|
4 | 4 | use crate::attr::{AggregatedSpirvAttributes, IntrinsicType}; |
5 | 5 | use crate::codegen_cx::CodegenCx; |
| 6 | +use crate::maybe_pqp_cg_ssa::traits::ConstCodegenMethods as _; |
6 | 7 | use crate::spirv_type::SpirvType; |
7 | 8 | use itertools::Itertools; |
8 | 9 | use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word}; |
@@ -807,6 +808,48 @@ fn trans_intrinsic_type<'tcx>( |
807 | 808 | args: GenericArgsRef<'tcx>, |
808 | 809 | intrinsic_type_attr: IntrinsicType, |
809 | 810 | ) -> Result<Word, ErrorGuaranteed> { |
| 811 | + trait FromScalarInt: Sized { |
| 812 | + fn from_scalar_int(n: ScalarInt) -> Option<Self>; |
| 813 | + } |
| 814 | + |
| 815 | + impl FromScalarInt for u32 { |
| 816 | + fn from_scalar_int(n: ScalarInt) -> Option<Self> { |
| 817 | + Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap()) |
| 818 | + } |
| 819 | + } |
| 820 | + |
| 821 | + impl FromScalarInt for Dim { |
| 822 | + fn from_scalar_int(n: ScalarInt) -> Option<Self> { |
| 823 | + Dim::from_u32(u32::from_scalar_int(n)?) |
| 824 | + } |
| 825 | + } |
| 826 | + |
| 827 | + impl FromScalarInt for ImageFormat { |
| 828 | + fn from_scalar_int(n: ScalarInt) -> Option<Self> { |
| 829 | + ImageFormat::from_u32(u32::from_scalar_int(n)?) |
| 830 | + } |
| 831 | + } |
| 832 | + |
| 833 | + fn const_int_value<'tcx, P: FromScalarInt>( |
| 834 | + cx: &CodegenCx<'tcx>, |
| 835 | + const_: Const<'tcx>, |
| 836 | + ) -> Result<P, ErrorGuaranteed> { |
| 837 | + let ty::Value { |
| 838 | + ty: const_ty, |
| 839 | + valtree: const_val, |
| 840 | + } = const_.to_value(); |
| 841 | + assert!(const_ty.is_integral()); |
| 842 | + const_val |
| 843 | + .try_to_scalar() |
| 844 | + .and_then(|scalar| scalar.try_to_scalar_int().ok()) |
| 845 | + .and_then(P::from_scalar_int) |
| 846 | + .ok_or_else(|| { |
| 847 | + cx.tcx |
| 848 | + .dcx() |
| 849 | + .err(format!("invalid value for const generic: {const_}")) |
| 850 | + }) |
| 851 | + } |
| 852 | + |
810 | 853 | match intrinsic_type_attr { |
811 | 854 | IntrinsicType::GenericImageType => { |
812 | 855 | // see SpirvType::sizeof |
@@ -870,48 +913,6 @@ fn trans_intrinsic_type<'tcx>( |
870 | 913 | // let image_format: spirv::ImageFormat = |
871 | 914 | // type_from_variant_discriminant(cx, args.const_at(6)); |
872 | 915 |
|
873 | | - trait FromScalarInt: Sized { |
874 | | - fn from_scalar_int(n: ScalarInt) -> Option<Self>; |
875 | | - } |
876 | | - |
877 | | - impl FromScalarInt for u32 { |
878 | | - fn from_scalar_int(n: ScalarInt) -> Option<Self> { |
879 | | - Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap()) |
880 | | - } |
881 | | - } |
882 | | - |
883 | | - impl FromScalarInt for Dim { |
884 | | - fn from_scalar_int(n: ScalarInt) -> Option<Self> { |
885 | | - Dim::from_u32(u32::from_scalar_int(n)?) |
886 | | - } |
887 | | - } |
888 | | - |
889 | | - impl FromScalarInt for ImageFormat { |
890 | | - fn from_scalar_int(n: ScalarInt) -> Option<Self> { |
891 | | - ImageFormat::from_u32(u32::from_scalar_int(n)?) |
892 | | - } |
893 | | - } |
894 | | - |
895 | | - fn const_int_value<'tcx, P: FromScalarInt>( |
896 | | - cx: &CodegenCx<'tcx>, |
897 | | - const_: Const<'tcx>, |
898 | | - ) -> Result<P, ErrorGuaranteed> { |
899 | | - let ty::Value { |
900 | | - ty: const_ty, |
901 | | - valtree: const_val, |
902 | | - } = const_.to_value(); |
903 | | - assert!(const_ty.is_integral()); |
904 | | - const_val |
905 | | - .try_to_scalar() |
906 | | - .and_then(|scalar| scalar.try_to_scalar_int().ok()) |
907 | | - .and_then(P::from_scalar_int) |
908 | | - .ok_or_else(|| { |
909 | | - cx.tcx |
910 | | - .dcx() |
911 | | - .err(format!("invalid value for Image const generic: {const_}")) |
912 | | - }) |
913 | | - } |
914 | | - |
915 | 916 | let dim = const_int_value(cx, args.const_at(1))?; |
916 | 917 | let depth = const_int_value(cx, args.const_at(2))?; |
917 | 918 | let arrayed = const_int_value(cx, args.const_at(3))?; |
@@ -941,6 +942,54 @@ fn trans_intrinsic_type<'tcx>( |
941 | 942 | Ok(SpirvType::AccelerationStructureKhr.def(span, cx)) |
942 | 943 | } |
943 | 944 | IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)), |
| 945 | + IntrinsicType::CooperativeMatrixKhr => { |
| 946 | + if ty.size != Size::from_bytes(4) { |
| 947 | + return Err(cx.tcx.dcx().err("cooperative_matrix type must have size 4")); |
| 948 | + } |
| 949 | + |
| 950 | + // Generic arg 0: component type T |
| 951 | + let component_type = match args.type_at(0).kind() { |
| 952 | + TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx), |
| 953 | + TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx), |
| 954 | + TyKind::Int(IntTy::I8) => SpirvType::Integer(8, true).def(span, cx), |
| 955 | + TyKind::Int(IntTy::I16) => SpirvType::Integer(16, true).def(span, cx), |
| 956 | + TyKind::Int(IntTy::I32) => SpirvType::Integer(32, true).def(span, cx), |
| 957 | + TyKind::Int(IntTy::I64) => SpirvType::Integer(64, true).def(span, cx), |
| 958 | + TyKind::Uint(UintTy::U8) => SpirvType::Integer(8, false).def(span, cx), |
| 959 | + TyKind::Uint(UintTy::U16) => SpirvType::Integer(16, false).def(span, cx), |
| 960 | + TyKind::Uint(UintTy::U32) => SpirvType::Integer(32, false).def(span, cx), |
| 961 | + TyKind::Uint(UintTy::U64) => SpirvType::Integer(64, false).def(span, cx), |
| 962 | + _ => { |
| 963 | + return Err(cx.tcx.dcx().span_err( |
| 964 | + span, |
| 965 | + "unsupported component type for #[spirv(cooperative_matrix)]: \ |
| 966 | + must be f32, f64, i8, i16, i32, i64, u8, u16, u32, or u64", |
| 967 | + )); |
| 968 | + } |
| 969 | + }; |
| 970 | + |
| 971 | + // Const generic 1: USE (MatrixA=0, MatrixB=1, MatrixAccumulator=2) |
| 972 | + // Const generic 2: ROWS |
| 973 | + // Const generic 3: COLS |
| 974 | + let use_val: u32 = const_int_value(cx, args.const_at(1))?; |
| 975 | + let rows_val: u32 = const_int_value(cx, args.const_at(2))?; |
| 976 | + let cols_val: u32 = const_int_value(cx, args.const_at(3))?; |
| 977 | + |
| 978 | + // Scope: Subgroup = 3 |
| 979 | + let scope = cx.const_u32(3).def_cx(cx); |
| 980 | + let rows = cx.const_u32(rows_val).def_cx(cx); |
| 981 | + let columns = cx.const_u32(cols_val).def_cx(cx); |
| 982 | + let use_ = cx.const_u32(use_val).def_cx(cx); |
| 983 | + |
| 984 | + Ok(SpirvType::CooperativeMatrixKhr { |
| 985 | + component_type, |
| 986 | + scope, |
| 987 | + rows, |
| 988 | + columns, |
| 989 | + use_, |
| 990 | + } |
| 991 | + .def(span, cx)) |
| 992 | + } |
944 | 993 | IntrinsicType::SampledImage => { |
945 | 994 | // see SpirvType::sizeof |
946 | 995 | if ty.size != Size::from_bytes(4) { |
|
0 commit comments