
[SYCL][Matrix] Add support for tf32 type using the unified interface #8702


Merged: 12 commits, Apr 6, 2023
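
For orientation, the sketch below is not code from this PR; it illustrates the pattern these changes enable through the unified interface: the `joint_matrix` element type is `precision::tf32` while the backing memory stays `float`, and element-wise access via `get_wi_data` reads and writes the `float` storage element type. The 8x8 shape, the required sub-group size of 8, the `address_space_cast` setup, and the function names (taken from the sycl_ext_oneapi_matrix / sycl_ext_intel_matrix extension documents) are assumptions; a real kernel must use a matrix combination supported by the target device.

```c++
#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;
namespace syclintelex = sycl::ext::intel::experimental::matrix;

// Sketch: load an 8x8 tile of float data into a tf32-typed joint_matrix,
// scale it element-wise, and store it back. Memory stays float throughout;
// only the matrix type is tagged as tf32.
void scale_tf32_tile(sycl::queue &q, float *A /* 8x8 row-major USM buffer */) {
  q.parallel_for(
       sycl::nd_range<2>({1, 8}, {1, 8}),
       [=](sycl::nd_item<2> it) [[sycl::reqd_sub_group_size(8)]] {
         sycl::sub_group sg = it.get_sub_group();

         // Element type precision::tf32, storage element type float.
         syclex::joint_matrix<sycl::sub_group, syclex::precision::tf32,
                              syclex::use::a, 8, 8, syclex::layout::row_major>
             tA;

         auto pA = sycl::address_space_cast<
             sycl::access::address_space::global_space,
             sycl::access::decorated::no>(A);

         syclex::joint_matrix_load(sg, tA, pA, 8);

         // Element-wise access: reads come back as float (the storage
         // element type); writes are cast back through float.
         auto wi = syclintelex::get_wi_data(sg, tA);
         for (std::size_t i = 0; i < wi.length(); ++i)
           wi[i] = static_cast<float>(wi[i]) * 2.0f;

         // The intel-extension store accepts a float pointer for a
         // tf32-typed matrix (the T/Tp split added in this PR).
         syclintelex::joint_matrix_store(sg, tA, pA, 8);
       })
      .wait();
}
```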
36 changes: 22 additions & 14 deletions sycl/include/CL/__spirv/spirv_ops.hpp
@@ -25,19 +25,23 @@
#ifdef __SYCL_DEVICE_ONLY__

#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *Object,
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

@@ -100,11 +104,13 @@ extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
@@ -119,18 +125,20 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
extern __DPCPP_SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
T val, size_t i);
Ts val, size_t i);
#else
template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
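
The following standalone snippet (illustration only, not code from this PR) captures why the intrinsic declarations above gained a second type parameter: for a joint matrix whose element type is tf32, the value actually moved by load/store/extract/insert has the storage element type `float`, so that type can no longer be deduced from the matrix element type alone.

```c++
// Local stand-in for the element-type -> storage-element-type mapping that the
// new Tp/Ts template parameters make explicit. The real trait is
// jm_type_interpretation_helper_trait in matrix-intel.hpp; tf32 here is a
// placeholder tag, not the SYCL type.
#include <cstdint>
#include <type_traits>

struct tf32 {};

template <typename T> struct storage_of { using type = T; };  // default: unchanged
template <> struct storage_of<tf32> { using type = float; };  // tf32 is stored as float

template <typename T> using storage_of_t = typename storage_of<T>::type;

static_assert(std::is_same_v<storage_of_t<float>, float>);
static_assert(std::is_same_v<storage_of_t<std::int8_t>, std::int8_t>);
static_assert(std::is_same_v<storage_of_t<tf32>, float>);  // element tf32, storage float
```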
124 changes: 103 additions & 21 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
@@ -67,10 +67,26 @@ struct joint_matrix;

} // namespace matrix
} // namespace experimental

namespace detail {
// Differentiating between the "element type" and the "storage element type"
template <typename T> struct jm_type_interpretation_helper_trait {
using element_type = T;
using storage_element_type = T;
};

template <>
struct jm_type_interpretation_helper_trait<
sycl::ext::oneapi::experimental::matrix::precision::tf32> {
using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
using storage_element_type = float;
};
} // namespace detail
} // namespace oneapi

namespace intel::experimental::matrix {

using namespace sycl::ext::oneapi::experimental::matrix;
// Begin wi_element definition

template <typename T, size_t NumRows, size_t NumCols,
@@ -84,6 +100,9 @@ class wi_element {
std::size_t idx;

public:
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, T, Use, NumRows, NumCols, Layout> &Mat,
std::size_t i)
@@ -102,9 +121,15 @@ class wi_element {
#endif // __SYCL_DEVICE_ONLY__
}

operator T() {
operator storage_element_type() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
storage_element_type elem =
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(M.spvm,
idx);
return elem;
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -113,7 +138,12 @@

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(
M.spvm, idx) != static_cast<storage_element_type>(0);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -122,7 +152,8 @@

template <typename T2> wi_element &operator=(const T2 &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, static_cast<storage_element_type>(rhs), idx);
return *this;
#else
(void)rhs;
@@ -135,7 +166,13 @@
operator=(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
M.spvm,
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(rhs.M.spvm,
rhs.idx),
idx);
return *this;
#else
(void)rhs;
@@ -149,8 +186,13 @@
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
op static_cast<T>(rhs)), \
static_cast<storage_element_type>( \
__spirv_VectorExtractDynamic< \
storage_element_type, T, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(M.spvm, idx) \
op static_cast<storage_element_type>(rhs)), \
idx); \
return *this; \
}
@@ -201,7 +243,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

operator sycl::ext::oneapi::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
return __spirv_VectorExtractDynamic<
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows,
NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(M.spvm, idx);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -210,8 +256,13 @@

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
return std::fabs(static_cast<float>(
__spirv_VectorExtractDynamic<
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(M.spvm, idx))) >=
std::numeric_limits<float>::epsilon();
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -233,7 +284,14 @@
NumCols, Use, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
M.spvm,
__spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows,
NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(rhs.M.spvm,
rhs.idx),
idx);
return *this;
#else
(void)rhs;
@@ -246,7 +304,13 @@
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
M.spvm, \
__spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
idx); \
return *this; \
}
#else // __SYCL_DEVICE_ONLY__
@@ -269,13 +333,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
return __spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
} \
friend type operator op( \
const sycl::ext::oneapi::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &rhs) { \
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
return __spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
}
OP(sycl::ext::oneapi::bfloat16, +)
OP(sycl::ext::oneapi::bfloat16, -)
@@ -287,15 +359,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
return type{static_cast<float>( \
__spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
op static_cast<float>(rhs)}; \
} \
friend type operator op( \
const sycl::ext::oneapi::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
return type{static_cast<float>( \
__spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
op static_cast<float>(lhs)}; \
}
OP(bool, ==)
OP(bool, !=)
@@ -386,7 +468,7 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
// End wi_data definition

template <
typename Group, typename T,
typename Group, typename T, typename Tp,
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
access::address_space Space, access::decorated IsDecorated,
@@ -396,7 +478,7 @@ template <
inline __SYCL_ALWAYS_INLINE void
joint_matrix_store(Group sg,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, T, Use, NumRows, NumCols, Layout> &src,
Group, Tp, Use, NumRows, NumCols, Layout> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
@@ -411,7 +493,7 @@ joint_matrix_store(Group sg,
#else
// intel's impl
T *Ptr = dst.get();
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
__spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::
spv_matrix_use_traits<Use>::value,
sycl::ext::oneapi::experimental::matrix::
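
In practice, the wi_element changes in this file mean element-wise access on a tf32 matrix round-trips through float. A fragment illustrating this (assumed to sit inside a kernel that already has the sub-group `sg` and a tf32 `joint_matrix` `tA`, as in the sketch near the top of this page):

```c++
auto wi = sycl::ext::intel::experimental::matrix::get_wi_data(sg, tA);
float e = wi[0];   // operator storage_element_type(): the element is extracted as float
wi[0] = e + 1.0f;  // operator=: the float value is cast back on insert
wi[0] *= 0.5f;     // compound assignments also go through the float storage type
if (wi[0]) { /* explicit operator bool() compares the float value against zero */ }
```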
6 changes: 0 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp
@@ -18,12 +18,6 @@ namespace oneapi {
namespace experimental {
namespace matrix {

namespace precision {
class tf32 {
tf32() = delete;
};
} // namespace precision

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
layout Layout = layout::dynamic>
struct joint_matrix;
Changes to a fourth file (file name not captured in this view)
@@ -18,6 +18,12 @@ enum class use { a, b, accumulator };

enum class layout { row_major = 0, col_major = 1, dynamic = 3 };

namespace precision {
class tf32 {
tf32() = delete;
};
} // namespace precision

} // namespace matrix
} // namespace experimental
} // namespace oneapi
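
With the move above, `precision::tf32` lives alongside the unified `use` and `layout` enums and is a pure tag type: its constructor is deleted, so it can only appear as a `joint_matrix` template argument, never as a value. A quick compile-time check, assuming `<sycl/sycl.hpp>` is included:

```c++
#include <sycl/sycl.hpp>
#include <type_traits>

using tf32 = sycl::ext::oneapi::experimental::matrix::precision::tf32;
static_assert(!std::is_default_constructible_v<tf32>,
              "tf32 is a tag type; matrix elements are never materialized as tf32 objects");
```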