Skip to content

Commit fde03de

Browse files
authored
fix compiling error (#98)
fix compiling error
1 parent 7bb2d01 commit fde03de

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

csrc/utils.cuh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,11 @@ T __device__ __forceinline__ ADD2(T a, T b) {
256 256     template <typename T>
257 257     T __device__ __forceinline__ ZERO_VALUE(T a) {
258 258     if constexpr (std::is_same<T, __bfloat16>::value) {
259     -   return __ushort_as_bfloat16((unsigned short)0x0000U);
    259 +   #if defined(USE_ROCM)
    260 +   return __float2bfloat16(0.0f);
    261 +   #else
    262 +   return __float2bfloat16_rn(0.0f);
    263 +   #endif
260 264     } else if constexpr (std::is_same<T, float>::value) {
261 265     return 0.0f;
262 266     } else {

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def build_cuda_extensions():
46 46     arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
47 47     print(" build for compute capabilities: ==============", compute_capabilities)
48 48
   49 +   # set nvcc threads
   50 +   nvcc_threads = os.getenv("NVCC_THREADS") or "4"
   51 +
49 52     extra_compile_args = {
50 53     "nvcc": [
51 54     "-O3",
@@ -58,6 +61,7 @@ def build_cuda_extensions():
58 61     "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
59 62     "-U__CUDA_NO_BFLOAT162_OPERATORS__",
60 63     "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
   64 +   f"--threads={nvcc_threads}",
61 65     ] + arch_flags,
62 66     "cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
63 67     }

0 commit comments

Comments
 (0)