@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
   });
 }
 
+#ifdef USE_ROCM
+void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
+    return static_cast<float>(value);
+  });
+}
+void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
+    return static_cast<float>(value);
+  });
+}
+#endif
+
 void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   ScalarType dtype = iter.dtype(0);
   ScalarType other_dtype = iter.dtype(1);
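Note: the two ROCm-only kernels above fold the widening cast into the copy lambda itself, so the source is read as `at::BFloat16`/`at::Half` and written out as `float` in one pass. A minimal standalone sketch of the same widen-on-copy pattern as a plain CUDA/HIP kernel, independent of `gpu_kernel_nocast` (the kernel name here is illustrative, not part of this PR):

```cpp
#include <cuda_bf16.h>

// Each thread reads one bfloat16 element and writes its exact float32
// equivalent; bf16 -> f32 is a pure widening conversion, so no rounding
// can occur.
__global__ void bf16_to_f32_copy(const __nv_bfloat16* src, float* dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    dst[i] = __bfloat162float(src[i]);
  }
}
```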
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     } else {
       float16_copy_kernel_cuda(iter);
     }
-  } else if (isBitsType(dtype)) {
+  }
+#ifdef USE_ROCM
+  else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
+    if (iter.dtype(1) == kBFloat16) {
+      bfloat16tofloat32_copy_kernel_cuda(iter);
+    } else {
+      float16tofloat32_copy_kernel_cuda(iter);
+    }
+  }
+#endif
+  else if (isBitsType(dtype)) {
     TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
       "bits types to different bits types. Source dtype is ", iter.dtype(1), " target dtype is ", dtype);
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
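For reference, a usage sketch (an assumption for illustration, not part of this diff) of a copy that would take the new dispatch arm on a ROCm build, where the source is BFloat16 and the destination dtype is Float:

```cpp
#include <ATen/ATen.h>

int main() {
  // BFloat16 tensor on the GPU; .to(kFloat) allocates a Float tensor and
  // invokes copy_, which reaches direct_copy_kernel_cuda with
  // iter.dtype(1) == kBFloat16 and dtype == kFloat.
  at::Tensor src = at::randn({1 << 20},
                             at::device(at::kCUDA).dtype(at::kBFloat16));
  // On ROCm this would route through bfloat16tofloat32_copy_kernel_cuda;
  // a Half source would take the float16tofloat32 branch instead.
  at::Tensor dst = src.to(at::kFloat);
  return 0;
}
```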