Skip to content

Commit d5a81e0

Browse files
authored
Fix overflow when calculating workgroups count (#2104)
Fixes #2070. This PR updates several SYCL kernel launch functions in `src/ATen/native/xpu/sycl/Loops.h` to use `int64_t` for the workgroup-size and workgroup-count calculations. This change prevents integer-overflow issues when handling large tensor sizes.
1 parent fa1e391 commit d5a81e0

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

src/ATen/native/xpu/sycl/Loops.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
314314

315315
auto ker = ElementwiseGroupRangeKernel<vec_size, func_t>(N, f);
316316

317-
int wg_sz = syclMaxWorkItemsPerSubSlice();
318-
int num_wg = ceil_div<int>(N, wg_sz * vec_size);
317+
int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
318+
int64_t num_wg = ceil_div<int64_t>(N, wg_sz * vec_size);
319319
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
320320
}
321321

@@ -328,9 +328,9 @@ static void launch_legacy_global_range_kernel(int64_t N, const func_t& f) {
328328

329329
auto ker = ElementwiseGlobalRangeKernel<func_t>(N, f);
330330

331-
int wg_sz = syclMaxWorkItemsPerSubSlice();
332-
int num_wg = ceil_div<int>(N, wg_sz);
333-
int hw_max_num_wg = syclMaxWorkItemsPerTile() / wg_sz;
331+
int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
332+
int64_t num_wg = ceil_div<int64_t>(N, wg_sz);
333+
int64_t hw_max_num_wg = syclMaxWorkItemsPerTile() / wg_sz;
334334
num_wg = num_wg > hw_max_num_wg ? hw_max_num_wg : num_wg;
335335
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
336336
}
@@ -355,8 +355,8 @@ static inline void launch_unrolled_kernel(
355355
auto ker = UnrolledElementwiseKernel(N, f, data, ic, oc, l, s);
356356
using ker_t = decltype(ker);
357357

358-
auto wg_sz = syclMaxWorkItemsPerSubSlice();
359-
int num_wg = ceil_div<int>(N, wg_sz * ker_t::item_work_size);
358+
int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
359+
int64_t num_wg = ceil_div<int64_t>(N, wg_sz * ker_t::item_work_size);
360360
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
361361
}
362362

@@ -393,13 +393,13 @@ static inline void launch_vectorized_kernel(
393393

394394
#define VEC_KER(vec_size) \
395395
{ \
396-
TORCH_CHECK(max_scalar_bytes* vec_size <= 16); \
396+
TORCH_CHECK(max_scalar_bytes * vec_size <= 16); \
397397
if constexpr (max_scalar_bytes * vec_size <= 16) { \
398398
auto ker = \
399399
VectorizedElementwiseKernel<vec_size, func_t, array_t, in_calc_t>( \
400400
N, f, data, input_calc); \
401-
int num_wg = ceil_div<int>(N, wg_sz * vec_size); \
402-
sycl_kernel_submit(wg_sz* num_wg, wg_sz, getCurrentSYCLQueue(), ker); \
401+
int64_t num_wg = ceil_div<int64_t>(N, wg_sz * vec_size); \
402+
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker); \
403403
} \
404404
}
405405

@@ -426,7 +426,7 @@ static inline void launch_vectorized_kernel(
426426
N, f, data, input_calc, output_calc, loader, storer);
427427
using ker_t = decltype(ker);
428428

429-
int num_wg = ceil_div<int>(N, wg_sz * ker_t::item_work_size);
429+
int64_t num_wg = ceil_div<int64_t>(N, wg_sz * ker_t::item_work_size);
430430
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
431431
break;
432432
}
@@ -457,8 +457,8 @@ static inline void launch_unrolled_kernel_for_multi_outputs(
457457
out_calc_t>(N, f, data, ic, oc);
458458
using ker_t = decltype(ker);
459459

460-
int wg_sz = syclMaxWorkItemsPerSubSlice();
461-
int num_wg = ceil_div<int>(N, ker_t::item_work_size * wg_sz);
460+
int64_t wg_sz = syclMaxWorkItemsPerSubSlice();
461+
int64_t num_wg = ceil_div<int64_t>(N, ker_t::item_work_size * wg_sz);
462462
sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
463463
}
464464

test/xpu/test_tensor_creation_ops_xpu.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4371,6 +4371,14 @@ def test_full_like_inference(self, device):
43714371
torch.full_like(like, 1.0, dtype=torch.complex64).dtype, torch.complex64
43724372
)
43734373

4374+
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
4375+
def test_zeros_large(self, device, dtype):
4376+
output = torch.zeros(2**31 - 1, device=device, dtype=dtype)
4377+
4378+
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
4379+
def test_ones_large(self, device, dtype):
4380+
output = torch.ones(2**31 - 1, device=device, dtype=dtype)
4381+
43744382

43754383
# Tests for the `frombuffer` function (only work on CPU):
43764384
# Constructs tensors from Python objects that implement the buffer protocol,

0 commit comments

Comments
 (0)