Skip to content

Commit 167f7e2

Browse files
authored
[ROCm] Implement float32 copy kernel (#2682)
cherry-pick of pytorch#163869
1 parent 99ccf24 commit 167f7e2

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
4242
});
4343
}
4444

45+
#ifdef USE_ROCM
// ROCm-only no-cast fast path: copies a kBFloat16 source into a kFloat
// destination by widening each element on the device, bypassing the
// generic casting machinery.
void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
  // bf16 -> float is exact: every BFloat16 value is representable in float.
  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 src) -> float {
    return static_cast<float>(src);
  });
}
51+
// ROCm-only no-cast fast path: copies a kHalf source into a kFloat
// destination by widening each element on the device.
void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
  // fp16 -> float is exact: every Half value is representable in float.
  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half src) -> float {
    return static_cast<float>(src);
  });
}
56+
#endif
57+
4558
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
4659
ScalarType dtype = iter.dtype(0);
4760
ScalarType other_dtype = iter.dtype(1);
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
187200
} else {
188201
float16_copy_kernel_cuda(iter);
189202
}
190-
} else if (isBitsType(dtype)) {
203+
}
204+
#ifdef USE_ROCM
205+
else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
206+
if (iter.dtype(1) == kBFloat16) {
207+
bfloat16tofloat32_copy_kernel_cuda(iter);
208+
} else {
209+
float16tofloat32_copy_kernel_cuda(iter);
210+
}
211+
}
212+
#endif
213+
else if (isBitsType(dtype)) {
191214
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
192215
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
193216
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {

0 commit comments

Comments
 (0)