Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -855,13 +855,13 @@ def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs,
// Abs, Neg bf16, bf16x2
//

def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $dst;", Int16Regs,
def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $src0;", Int16Regs,
Int16Regs, int_nvvm_abs_bf16, [hasPTX70, hasSM80]>;
def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $dst;", Int32Regs,
def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $src0;", Int32Regs,
Int32Regs, int_nvvm_abs_bf16x2, [hasPTX70, hasSM80]>;
def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $dst;", Int16Regs,
def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", Int16Regs,
Int16Regs, int_nvvm_neg_bf16, [hasPTX70, hasSM80]>;
def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $dst;", Int32Regs,
def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs,
Int32Regs, int_nvvm_neg_bf16x2, [hasPTX70, hasSM80]>;

//
Expand Down
24 changes: 24 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,30 @@ extern SYCL_EXTERNAL __ocl_vec_t<_Float16, 8>
extern SYCL_EXTERNAL __ocl_vec_t<_Float16, 16>
__clc_native_exp2(__ocl_vec_t<_Float16, 16>);

// Declares four external functions whose return type and every parameter type
// are the macro argument: __clc_fabs (one operand), __clc_fmin and __clc_fmax
// (two operands) and __clc_fma (three operands). All are declared noexcept.
// The macro is variadic so that a type containing a comma, such as
// __ocl_vec_t<TYPE, 2>, can be passed as a single argument.
#define __CLC_BF16(...) \
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs( \
__VA_ARGS__) noexcept; \
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin( \
__VA_ARGS__, __VA_ARGS__) noexcept; \
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax( \
__VA_ARGS__, __VA_ARGS__) noexcept; \
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma( \
__VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept;

// Expands __CLC_BF16 once for the scalar TYPE and once for each
// __ocl_vec_t<TYPE, N> with N = 2, 3, 4, 8 and 16.
#define __CLC_BF16_SCAL_VEC(TYPE) \
__CLC_BF16(TYPE) \
__CLC_BF16(__ocl_vec_t<TYPE, 2>) \
__CLC_BF16(__ocl_vec_t<TYPE, 3>) \
__CLC_BF16(__ocl_vec_t<TYPE, 4>) \
__CLC_BF16(__ocl_vec_t<TYPE, 8>) \
__CLC_BF16(__ocl_vec_t<TYPE, 16>)

// Emit the declarations for both element types.
// NOTE(review): presumably uint16_t carries a single bf16 value and uint32_t a
// packed bf16x2 pair -- confirm against the built-in implementations.
__CLC_BF16_SCAL_VEC(uint16_t)
__CLC_BF16_SCAL_VEC(uint32_t)

// The helper macros are only needed for the declarations above; remove them so
// they are not visible to the rest of the header.
#undef __CLC_BF16_SCAL_VEC
#undef __CLC_BF16

#else // if !__SYCL_DEVICE_ONLY__

template <typename dataT>
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
#include <sycl/ext/oneapi/backend/level_zero.hpp>
#endif
#include <sycl/ext/oneapi/bf16_storage_builtins.hpp>
#include <sycl/ext/oneapi/experimental/builtins.hpp>
#include <sycl/ext/oneapi/filter_selector.hpp>
#include <sycl/ext/oneapi/group_algorithm.hpp>
Expand Down
79 changes: 79 additions & 0 deletions sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#pragma once

#include <CL/__spirv/spirv_ops.hpp>
#include <CL/sycl/builtins.hpp>
#include <CL/sycl/detail/builtins.hpp>
#include <CL/sycl/detail/generic_type_lists.hpp>
#include <CL/sycl/detail/generic_type_traits.hpp>
#include <CL/sycl/detail/type_traits.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {

namespace detail {

/// Trait reporting whether T is one of the integer types accepted by the
/// bf16 storage built-ins declared in this header.
///
/// The primary template is the negative case: every type that has no
/// specialization is rejected.
///
/// `value` is declared as `bool` (rather than `int`) so that it can be used
/// as the condition of `std::enable_if_t` or `static_assert` without relying
/// on an implicit integer-to-bool conversion.
template <typename T> struct is_bf16_storage_type {
  static constexpr bool value = false;
};

/// 16-bit unsigned scalars are accepted.
template <> struct is_bf16_storage_type<uint16_t> {
  static constexpr bool value = true;
};

/// 32-bit unsigned scalars are accepted.
template <> struct is_bf16_storage_type<uint32_t> {
  static constexpr bool value = true;
};

/// A `vec` of 16-bit unsigned elements is accepted for every element count N.
/// `value` is a `bool` so it can serve directly as an `std::enable_if_t`
/// condition without an implicit integer-to-bool conversion.
template <int N> struct is_bf16_storage_type<vec<uint16_t, N>> {
  static constexpr bool value = true;
};

/// A `vec` of 32-bit unsigned elements is accepted for every element count N.
template <int N> struct is_bf16_storage_type<vec<uint32_t, N>> {
  static constexpr bool value = true;
};

} // namespace detail

/// Forwards to the `__clc_fabs` device built-in.
///
/// Participates in overload resolution only when
/// `detail::is_bf16_storage_type<T>::value` is true.
///
/// \param x Operand, passed unchanged to `__clc_fabs`.
/// \return The result of `__clc_fabs(x)` when compiled with
///         `__SYCL_DEVICE_ONLY__` defined.
/// \throws runtime_error with code `PI_INVALID_DEVICE` when compiled without
///         `__SYCL_DEVICE_ONLY__`; no host implementation is provided.
template <typename T>
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fabs(T x) {
#ifdef __SYCL_DEVICE_ONLY__
  return __clc_fabs(x);
#else
  // The host path never reads the argument; mark it as used so that
  // -Wunused-parameter stays quiet for code including this header.
  (void)x;
  throw runtime_error("bf16 is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif
}
/// Forwards to the `__clc_fmin` device built-in.
///
/// Participates in overload resolution only when
/// `detail::is_bf16_storage_type<T>::value` is true.
///
/// \param x First operand, passed unchanged to `__clc_fmin`.
/// \param y Second operand, passed unchanged to `__clc_fmin`.
/// \return The result of `__clc_fmin(x, y)` when compiled with
///         `__SYCL_DEVICE_ONLY__` defined.
/// \throws runtime_error with code `PI_INVALID_DEVICE` when compiled without
///         `__SYCL_DEVICE_ONLY__`; no host implementation is provided.
template <typename T>
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fmin(T x, T y) {
#ifdef __SYCL_DEVICE_ONLY__
  return __clc_fmin(x, y);
#else
  // The host path never reads the arguments; mark them as used so that
  // -Wunused-parameter stays quiet for code including this header.
  (void)x;
  (void)y;
  throw runtime_error("bf16 is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif
}
/// Forwards to the `__clc_fmax` device built-in.
///
/// Participates in overload resolution only when
/// `detail::is_bf16_storage_type<T>::value` is true.
///
/// \param x First operand, passed unchanged to `__clc_fmax`.
/// \param y Second operand, passed unchanged to `__clc_fmax`.
/// \return The result of `__clc_fmax(x, y)` when compiled with
///         `__SYCL_DEVICE_ONLY__` defined.
/// \throws runtime_error with code `PI_INVALID_DEVICE` when compiled without
///         `__SYCL_DEVICE_ONLY__`; no host implementation is provided.
template <typename T>
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fmax(T x, T y) {
#ifdef __SYCL_DEVICE_ONLY__
  return __clc_fmax(x, y);
#else
  // The host path never reads the arguments; mark them as used so that
  // -Wunused-parameter stays quiet for code including this header.
  (void)x;
  (void)y;
  throw runtime_error("bf16 is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif
}
/// Forwards to the `__clc_fma` device built-in.
///
/// Participates in overload resolution only when
/// `detail::is_bf16_storage_type<T>::value` is true.
///
/// \param x First operand, passed unchanged to `__clc_fma`.
/// \param y Second operand, passed unchanged to `__clc_fma`.
/// \param z Third operand, passed unchanged to `__clc_fma`.
/// \return The result of `__clc_fma(x, y, z)` when compiled with
///         `__SYCL_DEVICE_ONLY__` defined.
/// \throws runtime_error with code `PI_INVALID_DEVICE` when compiled without
///         `__SYCL_DEVICE_ONLY__`; no host implementation is provided.
template <typename T>
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fma(T x, T y, T z) {
#ifdef __SYCL_DEVICE_ONLY__
  return __clc_fma(x, y, z);
#else
  // The host path never reads the arguments; mark them as used so that
  // -Wunused-parameter stays quiet for code including this header.
  (void)x;
  (void)y;
  (void)z;
  throw runtime_error("bf16 is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif
}

} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)