[SYCL][Matrix] Add support for tf32 type using the unified interface #8151

Closed
wants to merge 8 commits
Changes from all commits
27 changes: 17 additions & 10 deletions sycl/include/CL/__spirv/spirv_ops.hpp
@@ -24,19 +24,23 @@
#ifdef __SYCL_DEVICE_ONLY__

#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern SYCL_EXTERNAL float __spirv_ConvertFToTF32INTEL(float a);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *Object,
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

@@ -95,10 +99,11 @@ __spirv_JointMatrixSUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
@@ -107,18 +112,20 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
extern SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
T val, size_t i);
Ts val, size_t i);
#else
template <typename T, std::size_t R, std::size_t C,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
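Note on the declarations above: `__spirv_ConvertFToTF32INTEL` rounds an IEEE `float` to TF32 precision (TF32 keeps the 8-bit exponent of `float` but only 10 mantissa bits), and the load/store/extract/insert builtins gain a second type parameter so the pointer element type (`T`, e.g. `float` in memory) can differ from the matrix component type (`Tp`, e.g. the `tf32` tag). The exact semantics of the conversion are defined by the SPIR-V extension and the backend; the snippet below is only a host-side sketch of one common round-to-nearest-even emulation, to illustrate which bits survive.

```cpp
#include <cstdint>
#include <cstring>

// Host-side sketch (not the SPIR-V builtin itself): round a float to TF32
// precision by rounding away the low 13 mantissa bits, round-to-nearest-even.
// NaN/Inf handling is omitted for brevity.
float emulate_tf32_round(float a) {
  std::uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  std::uint32_t lsb = (bits >> 13) & 1u; // last mantissa bit that is kept
  bits += 0x0FFFu + lsb;                 // round to nearest, ties to even
  bits &= 0xFFFFE000u;                   // drop the 13 discarded bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out; // still a float, but representable in TF32
}
```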
119 changes: 98 additions & 21 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
@@ -65,19 +65,39 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
layout Layout>
struct joint_matrix;

// Differentiating between the "element type" and the "storage element type"
template <typename T> struct helper_traits {
using element_type = T;
using storage_element_type = T;
using fill_argument_type = T;
};

template <> struct helper_traits<precision::tf32> {
using element_type = precision::tf32;
using storage_element_type = float;
using fill_argument_type = float;
};

template <typename T, size_t NumRows, size_t NumCols, use Use,
layout Layout = layout::dynamic, typename Group = sycl::sub_group>
class wi_element {
joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &M;
std::size_t idx;

public:
using storage_element_type = typename helper_traits<T>::storage_element_type;
wi_element(joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Mat,
std::size_t i)
: M(Mat), idx(i) {}
operator T() {
operator storage_element_type() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
storage_element_type elem =
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(M.spvm,
idx);
return elem;
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -86,7 +106,12 @@ class wi_element {

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(
M.spvm, idx) != static_cast<storage_element_type>(0);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -95,7 +120,8 @@

template <typename T2> wi_element &operator=(const T2 &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, static_cast<storage_element_type>(rhs), idx);
return *this;
#else
(void)rhs;
@@ -108,7 +134,13 @@
operator=(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
M.spvm,
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(rhs.M.spvm,
rhs.idx),
idx);
return *this;
#else
(void)rhs;
@@ -122,8 +154,13 @@
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
op static_cast<T>(rhs)), \
static_cast<storage_element_type>( \
__spirv_VectorExtractDynamic< \
storage_element_type, T, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(M.spvm, idx) \
op static_cast<storage_element_type>(rhs)), \
idx); \
return *this; \
}
@@ -157,7 +194,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
: M(Mat), idx(i) {}
operator sycl::ext::oneapi::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
return __spirv_VectorExtractDynamic<
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows,
NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(M.spvm, idx);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -166,8 +207,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
return std::fabs(static_cast<float>(
__spirv_VectorExtractDynamic<
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(M.spvm, idx))) >=
std::numeric_limits<float>::epsilon();
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
@@ -189,7 +235,14 @@
NumCols, Use, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
M.spvm,
__spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows,
NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value,
spv_scope_traits<Group>::value>(rhs.M.spvm,
rhs.idx),
idx);
return *this;
#else
(void)rhs;
@@ -202,7 +255,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
M.spvm, \
__spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
idx); \
return *this; \
}
#else // __SYCL_DEVICE_ONLY__
@@ -225,13 +284,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
return __spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
} \
friend type operator op( \
const sycl::ext::oneapi::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &rhs) { \
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
return __spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
}
OP(sycl::ext::oneapi::bfloat16, +)
OP(sycl::ext::oneapi::bfloat16, -)
@@ -243,15 +310,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
return type{static_cast<float>( \
__spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
op static_cast<float>(rhs)}; \
} \
friend type operator op( \
const sycl::ext::oneapi::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
return type{static_cast<float>( \
__spirv_VectorExtractDynamic< \
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
NumCols, spv_matrix_use_traits<Use>::value, \
spv_matrix_layout_traits<Layout>::value, \
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
op static_cast<float>(lhs)}; \
}
OP(bool, ==)
OP(bool, !=)
@@ -296,7 +373,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

namespace intel::experimental::matrix {
template <
typename Group, typename T,
typename Group, typename T, typename Tp,
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
access::address_space Space, access::decorated IsDecorated,
@@ -306,7 +383,7 @@ template <
inline __SYCL_ALWAYS_INLINE void
joint_matrix_store(Group sg,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, T, Use, NumRows, NumCols, Layout> &src,
Group, Tp, Use, NumRows, NumCols, Layout> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
Expand All @@ -321,7 +398,7 @@ joint_matrix_store(Group sg,
#else
// intel's impl
T *Ptr = dst.get();
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
__spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::
spv_matrix_use_traits<Use>::value,
sycl::ext::oneapi::experimental::matrix::
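The `helper_traits` machinery added above is what lets a matrix declared with the `precision::tf32` tag be filled and read element-wise as `float`: for every other element type the storage element type is the type itself, but for `tf32` it maps to `float`, and `wi_element` now routes every `__spirv_VectorExtractDynamic` / `__spirv_VectorInsertDynamic` call through that storage type. Below is a rough usage sketch, not a test from this PR; it assumes the extension's `joint_matrix_fill` and `get_wi_data` entry points and uses hypothetical tile dimensions.

```cpp
#include <sycl/sycl.hpp>

using namespace sycl::ext::oneapi::experimental::matrix;

// Device-side sketch: the matrix carries the tf32 tag as its element type,
// but every element the work-item touches is an ordinary float.
void scale_tf32_tile(sycl::sub_group sg) {
  joint_matrix<sycl::sub_group, precision::tf32, use::a, 8, 16,
               layout::row_major> tA;
  // fill_argument_type for tf32 is float, so the fill value is a plain float.
  joint_matrix_fill(sg, tA, 42.0f);
  auto wi = get_wi_data(sg, tA);
  for (size_t i = 0; i < wi.length(); ++i) {
    float e = wi[i];   // operator storage_element_type(), i.e. float
    wi[i] = e * 2.0f;  // operator= casts the float back into the matrix
  }
}
```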
6 changes: 0 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp
@@ -18,12 +18,6 @@ namespace oneapi {
namespace experimental {
namespace matrix {

namespace precision {
class tf32 {
tf32() = delete;
};
} // namespace precision

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
layout Layout = layout::dynamic>
struct joint_matrix;
@@ -18,6 +18,12 @@ enum class use { a, b, accumulator };

enum class layout { row_major = 0, col_major = 1, dynamic = 3 };

namespace precision {
class tf32 {
tf32() = delete;
};
} // namespace precision

} // namespace matrix
} // namespace experimental
} // namespace oneapi
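For reference, `precision::tf32` (moved here from matrix-tensorcores.hpp so the CUDA and Intel code paths share one definition) is a pure tag type: it cannot be constructed and only selects TF32 precision at the type level, while the data in memory and the per-element values remain `float`. A small illustrative check, not part of this PR:

```cpp
#include <sycl/sycl.hpp>
#include <type_traits>

namespace matrix = sycl::ext::oneapi::experimental::matrix;

// tf32's constructor is deleted, so the tag can never be instantiated;
// it is only ever used as a template argument, e.g. joint_matrix<..., tf32, ...>.
static_assert(!std::is_default_constructible_v<matrix::precision::tf32>,
              "precision::tf32 is a type-level tag, never a value");
```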