diff --git a/sycl/include/sycl/detail/id_queries_fit_in_int.hpp b/sycl/include/sycl/detail/id_queries_fit_in_int.hpp index 3f12b47bd429..34073e3177d4 100644 --- a/sycl/include/sycl/detail/id_queries_fit_in_int.hpp +++ b/sycl/include/sycl/detail/id_queries_fit_in_int.hpp @@ -34,46 +34,88 @@ inline namespace _V1 { namespace detail { #if __SYCL_ID_QUERIES_FIT_IN_INT__ -template struct NotIntMsg; +constexpr static const char *Msg = + "Provided range and/or offset does not fit in int. Pass " + "`-fno-sycl-id-queries-fit-in-int' to remove this limit."; -template struct NotIntMsg> { - constexpr static const char *Msg = - "Provided range is out of integer limits. Pass " - "`-fno-sycl-id-queries-fit-in-int' to disable range check."; -}; - -template struct NotIntMsg> { - constexpr static const char *Msg = - "Provided offset is out of integer limits. Pass " - "`-fno-sycl-id-queries-fit-in-int' to disable offset check."; -}; - -template +template typename std::enable_if_t::value || std::is_same::value> checkValueRangeImpl(ValT V) { static constexpr size_t Limit = static_cast((std::numeric_limits::max)()); if (V > Limit) - throw sycl::exception(make_error_code(errc::nd_range), NotIntMsg::Msg); + throw sycl::exception(make_error_code(errc::nd_range), Msg); +} + +inline void checkMulOverflow(size_t a, size_t b) { +#ifndef _MSC_VER + int Product; + // Since we must fit in SIGNED int, we can ignore the upper 32 bits. + if (__builtin_mul_overflow(unsigned(a), unsigned(b), &Product)) { + throw sycl::exception(make_error_code(errc::nd_range), Msg); + } +#else + checkValueRangeImpl(a); + checkValueRangeImpl(b); + size_t Product = a * b; + checkValueRangeImpl(Product); +#endif +} + +inline void checkMulOverflow(size_t a, size_t b, size_t c) { +#ifndef _MSC_VER + int Product; + // Since we must fit in SIGNED int, we can ignore the upper 32 bits. + if (__builtin_mul_overflow(unsigned(a), unsigned(b), &Product) || + __builtin_mul_overflow(Product, unsigned(c), &Product)) { + throw sycl::exception(make_error_code(errc::nd_range), Msg); + } +#else + checkValueRangeImpl(a); + checkValueRangeImpl(b); + size_t Product = a * b; + checkValueRangeImpl(Product); + + checkValueRangeImpl(c); + Product *= c; + checkValueRangeImpl(Product); +#endif +} + +// TODO: Remove this function when offsets are removed. +template +inline bool hasNonZeroOffset(const sycl::nd_range &V) { + size_t Product = 1; + for (int Dim = 0; Dim < Dims; ++Dim) { + Product *= V.get_offset()[Dim]; + } + return (Product != 0); } +#endif //__SYCL_ID_QUERIES_FIT_IN_INT__ + +template +void checkValueRange([[maybe_unused]] const sycl::range &V) { +#if __SYCL_ID_QUERIES_FIT_IN_INT__ + if constexpr (Dims == 1) { + // For 1D range, just check the value against MAX_INT. + checkValueRangeImpl(V[0]); + } else if constexpr (Dims == 2) { + // For 2D range, check if computing the linear range overflows. + checkMulOverflow(V[0], V[1]); + } else if constexpr (Dims == 3) { + // For 3D range, check if computing the linear range overflows. + checkMulOverflow(V[0], V[1], V[2]); + } #endif +} -template -typename std::enable_if_t> || - std::is_same_v>> -checkValueRange([[maybe_unused]] const T &V) { +template +void checkValueRange([[maybe_unused]] const sycl::id &V) { #if __SYCL_ID_QUERIES_FIT_IN_INT__ - for (size_t Dim = 0; Dim < Dims; ++Dim) - checkValueRangeImpl(V[Dim]); - - { - unsigned long long Product = 1; - for (size_t Dim = 0; Dim < Dims; ++Dim) { - Product *= V[Dim]; - // check value now to prevent product overflow in the end - checkValueRangeImpl(Product); - } + // An id cannot be linearized without a range, so check each component. + for (int Dim = 0; Dim < Dims; ++Dim) { + checkValueRangeImpl(V[Dim]); } #endif } @@ -87,21 +129,23 @@ void checkValueRange([[maybe_unused]] const range &R, for (size_t Dim = 0; Dim < Dims; ++Dim) { unsigned long long Sum = R[Dim] + O[Dim]; - - checkValueRangeImpl>(Sum); + checkValueRangeImpl(Sum); } #endif } -template -typename std::enable_if_t>> -checkValueRange([[maybe_unused]] const T &V) { +template +void checkValueRange([[maybe_unused]] const sycl::nd_range &V) { #if __SYCL_ID_QUERIES_FIT_IN_INT__ - checkValueRange(V.get_global_range()); - checkValueRange(V.get_local_range()); - checkValueRange(V.get_offset()); - - checkValueRange(V.get_global_range(), V.get_offset()); + // In an ND-range, we only need to check the global linear size, because: + // - The linear size must be greater than any of the dimensions. + // - Each dimension of the global range is larger than the local range. + // TODO: Remove this branch when offsets are removed. + if (hasNonZeroOffset(V)) /*[[unlikely]]*/ { + checkValueRange(V.get_global_range(), V.get_offset()); + } else { + checkValueRange(V.get_global_range()); + } #endif } diff --git a/sycl/test-e2e/Basic/range_offset_fit_in_int.cpp b/sycl/test-e2e/Basic/range_offset_fit_in_int.cpp index c58b6460918c..95c9fa38dd59 100644 --- a/sycl/test-e2e/Basic/range_offset_fit_in_int.cpp +++ b/sycl/test-e2e/Basic/range_offset_fit_in_int.cpp @@ -8,21 +8,17 @@ namespace S = sycl; -void checkRangeException(S::exception &E) { - constexpr char Msg[] = "Provided range is out of integer limits. " - "Pass `-fno-sycl-id-queries-fit-in-int' to " - "disable range check."; +constexpr char Msg[] = "Provided range and/or offset does not fit in int. " + "Pass `-fno-sycl-id-queries-fit-in-int' to " + "remove this limit."; +void checkRangeException(S::exception &E) { std::cerr << E.what() << std::endl; assert(std::string(E.what()).find(Msg) == 0 && "Unexpected message"); } void checkOffsetException(S::exception &E) { - constexpr char Msg[] = "Provided offset is out of integer limits. " - "Pass `-fno-sycl-id-queries-fit-in-int' to " - "disable offset check."; - std::cerr << E.what() << std::endl; assert(std::string(E.what()).find(Msg) == 0 && "Unexpected message"); @@ -48,8 +44,6 @@ void test() { S::id<2> OffsetInLimits_Large{(OutOfLimitsSize / 4) * 3, 1}; S::nd_range<2> NDRange_ROL_LIL_OIL{RangeOutOfLimits, RangeInLimits, OffsetInLimits}; - S::nd_range<2> NDRange_RIL_LOL_OIL{RangeInLimits, RangeOutOfLimits, - OffsetInLimits}; S::nd_range<2> NDRange_RIL_LIL_OOL{RangeInLimits, RangeInLimits, OffsetOutOfLimits}; S::nd_range<2> NDRange_RIL_LIL_OIL(RangeInLimits, RangeInLimits, @@ -184,22 +178,6 @@ void test() { assert(false && "Unexpected exception catched"); } - // small offset, local range is out of limits - try { - Queue.submit([&](S::handler &CGH) { - auto Acc = Buf.get_access(CGH); - - CGH.parallel_for( - NDRange_RIL_LOL_OIL, [Acc](S::nd_item<2> Id) { Acc[0] += 1; }); - }); - - assert(false && "Exception expected"); - } catch (S::exception &E) { - checkRangeException(E); - } catch (...) { - assert(false && "Unexpected exception catched"); - } - // large offset, ranges are in limits try { Queue.submit([&](S::handler &CGH) {