diff --git a/src/ATen/native/xpu/sycl/HermitePolynomialHKernel.cpp b/src/ATen/native/xpu/sycl/HermitePolynomialHKernel.cpp index 4862af9a8f..bffd6117b5 100644 --- a/src/ATen/native/xpu/sycl/HermitePolynomialHKernel.cpp +++ b/src/ATen/native/xpu/sycl/HermitePolynomialHKernel.cpp @@ -9,17 +9,15 @@ namespace at::native::xpu { template struct HermitePolynomialHFunctor { scalar_t operator()(scalar_t x, scalar_t n_) const { - int64_t n = static_cast(n_); + auto n = static_cast(n_); if (n < 0) { return scalar_t(0.0); - } - - if (n == 0) { + } else if (n == 0) { return scalar_t(1.0); - } - - if (n == 1) { + } else if (n == 1) { return x + x; + } else if (n > getHermitianLimit()) { + return std::numeric_limits::quiet_NaN(); } scalar_t p = scalar_t(1.0); diff --git a/src/ATen/native/xpu/sycl/HermitePolynomialHeKernel.cpp b/src/ATen/native/xpu/sycl/HermitePolynomialHeKernel.cpp index 6c37e87202..8ba1d8bfff 100644 --- a/src/ATen/native/xpu/sycl/HermitePolynomialHeKernel.cpp +++ b/src/ATen/native/xpu/sycl/HermitePolynomialHeKernel.cpp @@ -9,17 +9,15 @@ namespace at::native::xpu { template struct HermitePolynomialHeFunctor { scalar_t operator()(scalar_t x, scalar_t n_) const { - int64_t n = static_cast(n_); + auto n = static_cast(n_); if (n < 0) { return scalar_t(0.0); - } - - if (n == 0) { + } else if (n == 0) { return scalar_t(1.0); - } - - if (n == 1) { + } else if (n == 1) { return x; + } else if (n > getHermitianLimit()) { + return std::numeric_limits::quiet_NaN(); } scalar_t p = scalar_t(1.0); diff --git a/src/ATen/native/xpu/sycl/LaguerrePolynomialLKernel.cpp b/src/ATen/native/xpu/sycl/LaguerrePolynomialLKernel.cpp index 8688f80456..869d64492a 100644 --- a/src/ATen/native/xpu/sycl/LaguerrePolynomialLKernel.cpp +++ b/src/ATen/native/xpu/sycl/LaguerrePolynomialLKernel.cpp @@ -30,7 +30,7 @@ struct LaguerrePolynomialLFunctor { scalar_t q = scalar_t(1.0) - x; scalar_t r; - for (int64_t k = 1; k < n; k++) { + for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { r = (((k + k) + (scalar_t(1.0) - x)) * q - k * p) / (k + 1); p = q; q = r; diff --git a/src/ATen/native/xpu/sycl/LegendrePolynomialPKernel.cpp b/src/ATen/native/xpu/sycl/LegendrePolynomialPKernel.cpp index 2176c4e79c..caf1d8ac16 100644 --- a/src/ATen/native/xpu/sycl/LegendrePolynomialPKernel.cpp +++ b/src/ATen/native/xpu/sycl/LegendrePolynomialPKernel.cpp @@ -34,7 +34,7 @@ struct LegendrePolynomialPFunctor { scalar_t q = x; scalar_t r; - for (int64_t k = 1; k < n; k++) { + for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { r = ((k + k + 1) * x * q - k * p) / (k + 1); p = q; q = r; diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 5e70eebf44..9eeb6b44c4 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -74,12 +74,6 @@ "test_compare_cpu_linalg_lu_factor_ex_xpu_float32", "test_compare_cpu_linalg_lu_factor_xpu_float32", "test_compare_cpu_linalg_lu_xpu_float32", - # XPU hang. CUDA hang as well. - # https://github.com/pytorch/pytorch/issues/79528 - "test_compare_cpu_special_hermite_polynomial_he_xpu_float32", - "test_compare_cpu_special_hermite_polynomial_h_xpu_float32", - "test_compare_cpu_special_laguerre_polynomial_l_xpu_float32", - "test_compare_cpu_special_legendre_polynomial_p_xpu_float32", # core dump "test_dtypes__refs_nn_functional_pdist_xpu", # XFAIL of CUDA and XPU, unexpected success in fallback