Skip to content

Commit fa1e391

Browse files
authored
Fix unnecessary double data type conversion (#2114)
# Motivation The original issue occurs on some old iGPU running the following code on Windows: ```python import torch import torch.nn.functional as F print(torch.xpu.get_device_properties()) arr = torch.rand(1, 2, 5, 5, device='xpu') pts = torch.rand(1, 3, 3, 2, device='xpu') out = F.grid_sample(arr, pts, align_corners=False) ``` The failure output is: ```bash Traceback (most recent call last): File "C:\Vesuvius\urerr\urerr.py", line 22, in <module> out = F.grid_sample(arr, pts, align_corners=False) File "C:\Anaconda3\envs\pytn\Lib\site-packages\torch\nn\functional.py", line 5118, in grid_sample return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: UR error ``` The Driver team analysis located the crash in the generated spriv IR code. ```asm ; Function Attrs: nounwind define internal spir_func float @_ZN2at6native3xpuL19compute_coordinatesIfEET_S3_iNS0_6detail18GridSamplerPaddingEb(float %0, i32 %1, i32 %2, i1 zeroext %3) #0 !spirv.ParameterDecorations !393 { %5 = alloca double, align 8, !spirv.Decorations !394 switch i32 %2, label %58 [ i32 1, label %6 i32 2, label %14 ] 6: ; preds = %4 %7 = sext i32 %1 to i64 %8 = add nsw i64 %7, -1, !spirv.Decorations !387 %9 = sitofp i64 %8 to float %10 = fcmp olt float %0, 0.000000e+00 %11 = select i1 %10, float 0.000000e+00, float %0 %12 = fcmp olt float %11, %9 %13 = select i1 %12, float %11, float %9 br label %58 14: ; preds = %4 br i1 %3, label %15, label %32 15: ; preds = %14 %16 = shl i32 %1, 1 %17 = add i32 %16, -2 %18 = icmp eq i32 %17, 0 br i1 %18, label %49, label %19 19: ; preds = %15 %20 = sitofp i32 %17 to float %21 = fmul float %20, 5.000000e-01 %22 = call spir_func float @_Z16__spirv_ocl_fabsf(float %0) #0 %23 = call spir_func float @fmodf(float %22, float %21) #3 %24 = fdiv float %22, %21 %25 = call spir_func float @_Z17__spirv_ocl_floorf(float %24) #0 %26 = fptosi float %25 to i32 %27 = and i32 %26, 1 %28 = icmp eq i32 %27, 0 %29 = fsub float %21, %23 %30 = select i1 %28, float %23, float %29 %31 = fadd float %30, 0.000000e+00 br label %49 32: ; preds = %14 %33 = icmp eq i32 %1, 0 br i1 %33, label %49, label %34 34: ; preds = %32 %35 = shl nsw i32 %1, 1, !spirv.Decorations !387 %36 = sitofp i32 %35 to float %37 = fmul float %36, 5.000000e-01 %38 = fadd float %0, 5.000000e-01 %39 = call spir_func float @_Z16__spirv_ocl_fabsf(float %38) #0 %40 = call spir_func float @fmodf(float %39, float %37) #3 %41 = fdiv float %39, %37 %42 = call spir_func float @_Z17__spirv_ocl_floorf(float %41) #0 %43 = fptosi float %42 to i32 %44 = and i32 %43, 1 %45 = icmp eq i32 %44, 0 %46 = fsub float %37, %40 %47 = select i1 %45, float %40, float %46 %48 = fadd float %47, -5.000000e-01 br label %49 49: ; preds = %34, %32, %19, %15 %50 = phi float [ %31, %19 ], [ 0.000000e+00, %15 ], [ %48, %34 ], [ 0.000000e+00, %32 ] %51 = sext i32 %1 to i64 %52 = add nsw i64 %51, -1, !spirv.Decorations !387 %53 = sitofp i64 %52 to float %54 = fcmp olt float %50, 0.000000e+00 %55 = select i1 %54, float 0.000000e+00, float %50 %56 = fcmp olt float %55, %53 %57 = select i1 %56, float %55, float %53 br label %58 58: ; preds = %49, %6, %4 %59 = phi float [ %13, %6 ], [ %57, %49 ], [ %0, %4 ] %60 = fptosi float %59 to i64 %61 = icmp sgt i64 %60, 2147483646 %62 = fcmp olt float %59, 0xC1E0000000000000 %63 = or i1 %61, %62 br i1 %63, label %72, label %64 64: ; preds = %58 %65 = fpext float %59 to double %66 = bitcast double* %5 to i8* call void @llvm.lifetime.start.p0i8(i64 8, i8* %66) %67 = addrspacecast double* %5 to double addrspace(4)* store double %65, double* %5, align 8 %68 = call spir_func signext i16 @_dtest(double addrspace(4)* %67) #3 %69 = bitcast double* %5 to i8* call void @llvm.lifetime.end.p0i8(i64 8, i8* %69) %70 = icmp slt i16 %68, 1 %71 = select i1 %70, float %59, float -1.000000e+02 br label %72 72: ; preds = %64, %58 %73 = phi float [ %71, %64 ], [ -1.000000e+02, %58 ] ret float %73 } ``` We can see that spirv IR code uses a double type and calls the @_dtest function in block 64. Accroding to [MSVC document](https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/floating-point-primitives?view=msvc-170#_dtest-_ldtest-_fdtest), _dtest is used to detect whether a number is `Nan` or `INFINITE`. This allows us to locate the root cause of the crash, which corresponds to the following C++ logic: ```cpp if (static_cast<int64_t>(x) > INT_MAX - 1 || x < INT_MIN || !std::isfinite(static_cast<double>(x))) return static_cast<scalar_t>(-100.0); return x; ``` In other words, the crash occurs when the GPU executes code that tries to convert a floating-point value (Half or BFloat16) to a double and check whether it is finite. # Solution - For Half and BFloat16, `std::isfinite(x)` promot `x` to `float`, providing enough precision for finiteness checks. Casting to double is redundant and can be safely removed. - Explicitly return `-100.0f` instead of double type. # Additional Context I can't find the iGPU that could verify the fix, but it is unlikely to introduce any additional error.
1 parent 09edbee commit fa1e391

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/ATen/native/xpu/sycl/GridSampler.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ static inline scalar_t safe_downgrade_to_int_range(scalar_t x) {
4949
// -100.0 does not have special meaning. This is just to make sure
5050
// it's not within_bounds_2d or within_bounds_3d, and does not cause
5151
// undefined behavior.
52+
// We avoid using double here because some platforms may not support it.
5253
if (static_cast<int64_t>(x) > INT_MAX - 1 || x < INT_MIN ||
53-
!std::isfinite(static_cast<double>(x)))
54-
return static_cast<scalar_t>(-100.0);
54+
!std::isfinite(x))
55+
return static_cast<scalar_t>(-100.0f);
5556
return x;
5657
}
5758

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Owner(s): ["module: intel"]
2+
import torch
3+
import torch.nn.functional as F
4+
from torch.testing._internal.common_utils import TestCase
5+
6+
cpu_device = torch.device("cpu")
7+
xpu_device = torch.device("xpu")
8+
9+
10+
class TestSimpleCopy(TestCase):
11+
# Refer to https://github.com/pytorch/pytorch/issues/153996
12+
def test_grid_sample(self, dtype=torch.float):
13+
input_cpu = torch.rand(1, 2, 5, 5, device=cpu_device)
14+
grid_cpu = torch.rand(1, 3, 3, 2, device=cpu_device)
15+
out_cpu = F.grid_sample(input_cpu, grid_cpu, align_corners=False)
16+
input_xpu = input_cpu.to(xpu_device)
17+
grid_xpu = grid_cpu.to(xpu_device)
18+
out_xpu = F.grid_sample(input_xpu, grid_xpu, align_corners=False)
19+
self.assertEqual(out_cpu, out_xpu.to(cpu_device))

0 commit comments

Comments
 (0)