Skip to content

[SYCL][Matrix] Fix checked matrix instructions #13287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,32 @@ template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CooperativeMatrixConstructCheckedINTEL(
const T Value, uint32_t Height, size_t Stride, uint32_t Width,
int32_t CoordX, int32_t CoordY);
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
int32_t CoordY,
uint32_t Height,
uint32_t Width,
const T Value);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_JointMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
uint32_t Height, uint32_t Width,
int32_t CoordX, int32_t CoordY,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S,
int MemOperand = 0);
__spirv_CooperativeMatrixLoadCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
std::size_t Stride, uint32_t Height, uint32_t Width, int32_t CoordX,
int32_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S,
int MemOperand = 0);
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY,
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
std::size_t Stride = 0, int MemOperand = 0);

template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
Expand Down
66 changes: 30 additions & 36 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
layout Layout, typename T2>
inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
const T2 &Value, size_t Height, size_t Width, size_t CoordX,
size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
using storage_element_type =
Expand All @@ -622,12 +622,10 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
static_cast<storage_element_type>(Value), Stride, Height, Width, CoordX,
CoordY);
CoordX, CoordY, Height, Width, static_cast<storage_element_type>(Value));
#else
std::ignore = Res;
std::ignore = Value;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = CoordX;
Expand All @@ -654,13 +652,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
DecorT, S, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
Height, Width, Stride);
#else
std::ignore = sg;
std::ignore = Res;
Expand Down Expand Up @@ -694,11 +691,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
Width, Stride);
#else
std::ignore = sg;
std::ignore = Res;
Expand Down Expand Up @@ -727,13 +724,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
__spirv_JointMatrixStoreCheckedINTEL<
__spirv_CooperativeMatrixStoreCheckedINTEL<
DecorT, T, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
Ptr, CoordX, CoordY, Src.spvm,
sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
#else
std::ignore = sg;
std::ignore = Src;
Expand Down Expand Up @@ -763,11 +759,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
__spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
__spirv_CooperativeMatrixStoreCheckedINTEL<
DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
Height, Width, Stride);
#else
std::ignore = sg;
std::ignore = Src;
Expand Down Expand Up @@ -797,12 +793,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Src.get();
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
Height, Width, Stride);
#else
std::ignore = sg;
std::ignore = Res;
Expand Down Expand Up @@ -832,11 +827,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Src.get();
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
Width, Stride);
#else
std::ignore = sg;
std::ignore = Res;
Expand All @@ -863,12 +858,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Dst.get();
__spirv_JointMatrixStoreCheckedINTEL<
__spirv_CooperativeMatrixStoreCheckedINTEL<
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
Ptr, CoordX, CoordY, Src.spvm,
sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
#else
std::ignore = sg;
std::ignore = Src;
Expand All @@ -894,11 +888,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Dst.get();
__spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
__spirv_CooperativeMatrixStoreCheckedINTEL<
T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
Height, Width, Stride);
#else
std::ignore = sg;
std::ignore = Src;
Expand Down
Loading