@@ -314,8 +314,8 @@ static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
314314
315315 auto ker = ElementwiseGroupRangeKernel<vec_size, func_t >(N, f);
316316
317- int wg_sz = syclMaxWorkItemsPerSubSlice ();
318- int num_wg = ceil_div<int >(N, wg_sz * vec_size);
317+ int64_t wg_sz = syclMaxWorkItemsPerSubSlice ();
318+ int64_t num_wg = ceil_div<int64_t >(N, wg_sz * vec_size);
319319 sycl_kernel_submit (wg_sz * num_wg, wg_sz, getCurrentSYCLQueue (), ker);
320320}
321321
@@ -328,9 +328,9 @@ static void launch_legacy_global_range_kernel(int64_t N, const func_t& f) {
328328
329329 auto ker = ElementwiseGlobalRangeKernel<func_t >(N, f);
330330
331- int wg_sz = syclMaxWorkItemsPerSubSlice ();
332- int num_wg = ceil_div<int >(N, wg_sz);
333- int hw_max_num_wg = syclMaxWorkItemsPerTile () / wg_sz;
331+ int64_t wg_sz = syclMaxWorkItemsPerSubSlice ();
332+ int64_t num_wg = ceil_div<int64_t >(N, wg_sz);
333+ int64_t hw_max_num_wg = syclMaxWorkItemsPerTile () / wg_sz;
334334 num_wg = num_wg > hw_max_num_wg ? hw_max_num_wg : num_wg;
335335 sycl_kernel_submit (wg_sz * num_wg, wg_sz, getCurrentSYCLQueue (), ker);
336336}
@@ -355,8 +355,8 @@ static inline void launch_unrolled_kernel(
355355 auto ker = UnrolledElementwiseKernel (N, f, data, ic, oc, l, s);
356356 using ker_t = decltype (ker);
357357
358- auto wg_sz = syclMaxWorkItemsPerSubSlice ();
359- int num_wg = ceil_div<int >(N, wg_sz * ker_t ::item_work_size);
358+ int64_t wg_sz = syclMaxWorkItemsPerSubSlice ();
359+ int64_t num_wg = ceil_div<int64_t >(N, wg_sz * ker_t ::item_work_size);
360360 sycl_kernel_submit (wg_sz * num_wg, wg_sz, getCurrentSYCLQueue (), ker);
361361}
362362
@@ -393,13 +393,13 @@ static inline void launch_vectorized_kernel(
393393
394394#define VEC_KER (vec_size ) \
395395 { \
396- TORCH_CHECK (max_scalar_bytes* vec_size <= 16 ); \
396+ TORCH_CHECK (max_scalar_bytes * vec_size <= 16 ); \
397397 if constexpr (max_scalar_bytes * vec_size <= 16 ) { \
398398 auto ker = \
399399 VectorizedElementwiseKernel<vec_size, func_t , array_t , in_calc_t >( \
400400 N, f, data, input_calc); \
401- int num_wg = ceil_div<int >(N, wg_sz * vec_size); \
402- sycl_kernel_submit (wg_sz* num_wg, wg_sz, getCurrentSYCLQueue (), ker); \
401+ int64_t num_wg = ceil_div<int64_t >(N, wg_sz * vec_size); \
402+ sycl_kernel_submit (wg_sz * num_wg, wg_sz, getCurrentSYCLQueue (), ker); \
403403 } \
404404 }
405405
@@ -426,7 +426,7 @@ static inline void launch_vectorized_kernel(
426426 N, f, data, input_calc, output_calc, loader, storer);
427427 using ker_t = decltype (ker);
428428
429- int num_wg = ceil_div<int >(N, wg_sz * ker_t ::item_work_size);
429+ int64_t num_wg = ceil_div<int64_t >(N, wg_sz * ker_t ::item_work_size);
430430 sycl_kernel_submit (wg_sz * num_wg, wg_sz, getCurrentSYCLQueue (), ker);
431431 break ;
432432 }
@@ -457,8 +457,8 @@ static inline void launch_unrolled_kernel_for_multi_outputs(
457457 out_calc_t >(N, f, data, ic, oc);
458458 using ker_t = decltype (ker);
459459
460- int wg_sz = syclMaxWorkItemsPerSubSlice ();
461- int num_wg = ceil_div<int >(N, ker_t ::item_work_size * wg_sz);
460+ int64_t wg_sz = syclMaxWorkItemsPerSubSlice ();
461+ int64_t num_wg = ceil_div<int64_t >(N, ker_t ::item_work_size * wg_sz);
462462 sycl_kernel_submit (wg_sz * num_wg, wg_sz, getCurrentSYCLQueue (), ker);
463463}
464464
0 commit comments