diff --git a/src/ATen/native/xpu/sycl/DistributionTemplates.h b/src/ATen/native/xpu/sycl/DistributionTemplates.h
index 0a63278c2..c0a6bba2d 100644
--- a/src/ATen/native/xpu/sycl/DistributionTemplates.h
+++ b/src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -23,50 +24,6 @@ using namespace at::xpu;
 
 const uint32_t rand4_engine_calls = 4;
 
-struct PhiloxState {
-  PhiloxState() = default;
-  // Called if graph capture is not underway
-  PhiloxState(uint64_t seed, uint64_t offset) {
-    seed_ = seed;
-    offset_.val = offset;
-  }
-  // Called if graph capture is underway
-  PhiloxState(
-      uint64_t seed,
-      int64_t* offset_extragraph,
-      uint32_t offset_intragraph) {
-    seed_ = seed;
-    offset_.ptr = offset_extragraph;
-    offset_intragraph_ = offset_intragraph;
-    captured_ = true;
-  }
-
-  union Payload {
-    uint64_t val;
-    int64_t* ptr;
-  };
-
-  uint64_t seed_ = 0;
-  Payload offset_;
-  uint32_t offset_intragraph_ = 0;
-  bool captured_ = false;
-};
-
-inline std::tuple<uint64_t, uint64_t> philox_unpack(PhiloxState arg) {
-  if (arg.captured_) {
-    // static_cast avoids "warning: invalid narrowing conversion from "long" to
-    // "unsigned long".
-    // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire
-    // kernel. For most threads' reads it will hit in cache, so it shouldn't
-    // hurt performance.
-    return std::make_tuple(
-        arg.seed_,
-        static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
-  } else {
-    return std::make_tuple(arg.seed_, arg.offset_.val);
-  }
-}
-
 template
 inline std::tuple calc_execution_policy(
     int64_t total_elements) {
@@ -96,7 +53,7 @@ struct DistributionElementwiseKernelFunctor {
     int num_groups = item.get_group_range(0);
     int idx = item.get_global_linear_id();
 
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);
 
@@ -125,15 +82,13 @@ struct DistributionElementwiseKernelFunctor {
   }
   DistributionElementwiseKernelFunctor(
       int64_t numel,
-      std::pair<uint64_t, uint64_t> rng_engine_inputs,
+      PhiloxXpuState rng_engine_inputs,
       dist_t dist_func,
       transform_t transform_func,
       char* out_data,
       offset_calc_t offset_calc)
       : numel_(numel),
-        philox_args_(PhiloxState(
-            std::get<0>(rng_engine_inputs),
-            std::get<1>(rng_engine_inputs))),
+        philox_args_(rng_engine_inputs),
         dist_func_(dist_func),
         transform_func_(transform_func),
         out_data_(out_data),
@@ -141,7 +96,7 @@ struct DistributionElementwiseKernelFunctor {
 
  private:
   int64_t numel_;
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
   dist_t dist_func_;
   transform_t transform_func_;
   char* out_data_;
@@ -171,11 +126,11 @@ void distribution_nullary_kernel(
   auto num_groups = std::get<1>(execution_policy);
   auto group_size = std::get<2>(execution_policy);
 
-  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+    rng_engine_inputs = gen->philox_xpu_state(counter_offset);
   }
 
   if (!iter.can_use_32bit_indexing()) {
@@ -234,7 +189,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
     int global_size = item.get_global_range(0);
     int global_idx = item.get_group(0) * group_size + item.get_local_id(0);
 
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
 
@@ -247,7 +202,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
   DistributionUnaryElementwiseKernelFunctor(
       int numel,
       const func_t f,
-      PhiloxState philox_args,
+      PhiloxXpuState philox_args,
       scalar1_t* output_data,
       const scalar2_t* input_data,
       inp_offset_calc_t input_offset_calculator,
@@ -263,7 +218,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
  private:
   int numel_;
   const func_t f_;
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
   scalar1_t* output_data_;
   const scalar2_t* input_data_;
   inp_offset_calc_t inp_calc_;
@@ -273,7 +228,7 @@
 template
 void distribution_unary_kernel(
     TensorIterator& iter,
-    PhiloxState philox_args,
+    PhiloxXpuState philox_args,
     func_t f) {
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
@@ -340,7 +295,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
     int global_size = item.get_global_range(0);
     int global_idx = item.get_group(0) * group_size + item.get_local_id(0);
 
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
 
@@ -356,7 +311,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
   DistributionBinaryElementwiseKernelFunctor(
       int numel,
       func_t f,
-      PhiloxState philox_args,
+      PhiloxXpuState philox_args,
       output_t* output_data,
       const input_t_1* input_data_1,
       const input_t_2* input_data_2,
@@ -374,7 +329,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
  private:
   int64_t numel_;
   func_t f_;
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
   output_t* out_data_;
   const input_t_1* inp_data_1_;
   const input_t_2* inp_data_2_;
@@ -385,7 +340,7 @@
 template
 void distribution_binary_kernel(
     TensorIteratorBase& iter,
-    PhiloxState philox_args,
+    PhiloxXpuState philox_args,
     const func_t& f) {
   static_assert(
       std::is_same<
@@ -762,7 +717,7 @@ struct BernoulliTensorApplyFunctor {
       const prob_t& p2,
       const prob_t& p3,
       const prob_t& p4) const {
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(
         std::get<0>(seeds),
@@ -792,20 +747,18 @@ struct BernoulliTensorApplyFunctor {
       }
     }
   }
-  BernoulliTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
-      : philox_args_(
-            std::get<0>(rng_engine_inputs),
-            std::get<1>(rng_engine_inputs)) {}
+  BernoulliTensorApplyFunctor(PhiloxXpuState rng_engine_inputs)
+      : philox_args_(rng_engine_inputs) {}
 
  private:
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
 };
 
 template
 void bernoulli_tensor_kernel(
     TensorBase& ret,
     TensorBase& p,
-    std::pair<uint64_t, uint64_t> rng_engine_inputs) {
+    PhiloxXpuState rng_engine_inputs) {
   auto functor = BernoulliTensorApplyFunctor(rng_engine_inputs);
 
   // The template argument `4` below indicates that we want to operate on four
@@ -820,11 +773,11 @@ void bernoulli_tensor_kernel(
 
 template
 void bernoulli_kernel(const TensorBase& self, const TensorBase& p_, RNG gen) {
-  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs = gen->philox_engine_inputs(10);
+    rng_engine_inputs = gen->philox_xpu_state(10);
   }
   TORCH_CHECK(
       at::isFloatingType(p_.scalar_type()),
diff --git a/src/ATen/native/xpu/sycl/Distributions.cpp b/src/ATen/native/xpu/sycl/Distributions.cpp
index 2c9239ef3..2506777ec 100644
--- a/src/ATen/native/xpu/sycl/Distributions.cpp
+++ b/src/ATen/native/xpu/sycl/Distributions.cpp
@@ -17,7 +17,7 @@ struct PoissonTensorApplyFunctor {
     SYCL_KERNEL_ASSERT(
         lambda >= 0 &&
         "invalid Poisson rate, expected rate to be non-negative");
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(
         std::get<0>(seeds),
@@ -26,20 +26,18 @@ struct PoissonTensorApplyFunctor {
         &state);
     ret_val = static_cast<scalar_t>(rand_poisson(&state, lambda));
   }
-  PoissonTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
-      : philox_args_(
-            std::get<0>(rng_engine_inputs),
-            std::get<1>(rng_engine_inputs)) {}
+  PoissonTensorApplyFunctor(PhiloxXpuState rng_engine_inputs)
+      : philox_args_(rng_engine_inputs) {}
 
  private:
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
 };
 
 template
 void poisson_kernel(
     const at::TensorBase& ret,
     const at::TensorBase& lambda,
-    std::pair<uint64_t, uint64_t> rng_engine_inputs) {
+    PhiloxXpuState rng_engine_inputs) {
   auto functor = PoissonTensorApplyFunctor(rng_engine_inputs);
   at::native::xpu::tensor_apply2<
       scalar_t,
@@ -55,11 +53,11 @@ void launch_poisson_kernel(
     const TensorBase& ret,
     const TensorBase& lambda,
     at::XPUGeneratorImpl* gen) {
-  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs = gen->philox_engine_inputs(20);
+    rng_engine_inputs = gen->philox_xpu_state(20);
   }
   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
@@ -101,21 +99,19 @@ struct BinomialFunctor {
 };
 
 template
-void binomial_kernel(TensorIteratorBase& iter, PhiloxState philox_args) {
+void binomial_kernel(TensorIteratorBase& iter, PhiloxXpuState philox_args) {
   using accscalar_t = at::acc_type_device;
   BinomialFunctor f;
   at::native::xpu::distribution_binary_kernel(iter, philox_args, f);
 }
 
 void launch_binomial_kernel(TensorIteratorBase& iter, XPUGeneratorImpl* gen) {
-  std::pair<uint64_t, uint64_t> engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
-    engine_inputs = gen->philox_engine_inputs(42);
+    rng_engine_inputs = gen->philox_xpu_state(42);
   }
-  PhiloxState rng_engine_inputs(
-      std::get<0>(engine_inputs), std::get<1>(engine_inputs));
   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
@@ -130,7 +126,7 @@ struct GammaTensorApplyFunctor {
       sycl::nd_item<1> item,
       scalar_t& ret_val,
       const scalar_t& alpha) const {
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(
         std::get<0>(seeds),
@@ -155,18 +151,18 @@ struct GammaTensorApplyFunctor {
     ret_val = (min_value > sample) ?
        min_value : sample;
   }
-  GammaTensorApplyFunctor(PhiloxState philox_args)
+  GammaTensorApplyFunctor(PhiloxXpuState philox_args)
       : philox_args_(philox_args) {}
 
  private:
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
 };
 
 template
 void gamma_kernel(
     const at::TensorBase& ret,
     const at::TensorBase& alpha,
-    PhiloxState philox_args) {
+    PhiloxXpuState philox_args) {
   using accscalar_t = at::acc_type_device;
   GammaTensorApplyFunctor functor(philox_args);
   at::native::xpu::tensor_apply2<
@@ -183,7 +179,7 @@ void launch_gamma_kernel(
     Tensor& ret,
     const Tensor& alpha,
     XPUGeneratorImpl* gen) {
-  std::pair<uint64_t, uint64_t> engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
@@ -191,10 +187,9 @@ void launch_gamma_kernel(
     // This seed was chosen to ensure consistent random number generation
     // behavior for this specific kernel. Modify with caution as it affects
     // reproducibility of results.
-    engine_inputs = gen->philox_engine_inputs(10);
+    rng_engine_inputs = gen->philox_xpu_state(10);
   }
-  PhiloxState rng_engine_inputs(
-      std::get<0>(engine_inputs), std::get<1>(engine_inputs));
+
   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
diff --git a/src/ATen/native/xpu/sycl/Dropout.cpp b/src/ATen/native/xpu/sycl/Dropout.cpp
index acb3faee0..1641e074a 100644
--- a/src/ATen/native/xpu/sycl/Dropout.cpp
+++ b/src/ATen/native/xpu/sycl/Dropout.cpp
@@ -40,7 +40,7 @@ struct FusedDropoutVecFunctor {
     using LoadT = memory::aligned_vector;
     using MaskLoadT = memory::aligned_vector;
 
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     IndexType idx = item.get_global_linear_id();
     randStatePhilox4_32_10_t state;
     rand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);
@@ -112,7 +112,7 @@ struct FusedDropoutVecFunctor {
       TensorInfo c,
       IndexType total_elements,
       accscalar_t p,
-      PhiloxState philox_args)
+      PhiloxXpuState philox_args)
       : a_(a),
         b_(b),
         c_(c),
@@ -126,7 +126,7 @@ struct FusedDropoutVecFunctor {
   TensorInfo c_;
   IndexType total_elements_;
   accscalar_t p_;
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
 };
 
 template <
@@ -139,7 +139,7 @@
 struct FusedDropoutUnrollFunctor {
   void operator()(sycl::nd_item<1> item) const {
     constexpr int UNROLL = 4;
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     IndexType idx = item.get_global_linear_id();
     randStatePhilox4_32_10_t state;
     rand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);
@@ -187,7 +187,7 @@ struct FusedDropoutUnrollFunctor {
       TensorInfo c,
       IndexType total_elements,
       accscalar_t p,
-      PhiloxState philox_args)
+      PhiloxXpuState philox_args)
       : a_(a),
         b_(b),
         c_(c),
@@ -201,7 +201,7 @@ struct FusedDropoutUnrollFunctor {
   TensorInfo c_;
   IndexType total_elements_;
   accscalar_t p_;
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
 };
 
 template
@@ -267,7 +267,7 @@ inline void launcher(
     Tensor& mask,
     double p,
     const int64_t nelem,
-    const PhiloxState rng_engine_inputs,
+    const PhiloxXpuState rng_engine_inputs,
     uint32_t num_groups,
     uint32_t group_size) {
   AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -390,14 +390,12 @@ std::tuple dropout(
   std::tie(counter_offset, num_groups, group_size) =
       calc_execution_policy(nelem);
 
-  std::pair<uint64_t, uint64_t> rng_engine_inputs_;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs_ = gen->philox_engine_inputs(counter_offset);
+    rng_engine_inputs = gen->philox_xpu_state(counter_offset);
   }
-  PhiloxState rng_engine_inputs(
-      std::get<0>(rng_engine_inputs_), std::get<1>(rng_engine_inputs_));
 
   if (canUse32BitIndexMath(self)) {
     launcher(
diff --git a/src/ATen/native/xpu/sycl/MultinomialKernel.cpp b/src/ATen/native/xpu/sycl/MultinomialKernel.cpp
index a5590d5cb..09de352c6 100644
--- a/src/ATen/native/xpu/sycl/MultinomialKernel.cpp
+++ b/src/ATen/native/xpu/sycl/MultinomialKernel.cpp
@@ -153,7 +153,7 @@ inline int binarySearchForMultinomial(
 template
 inline void sampleMultinomialWithReplacement(
     item_t& item,
-    PhiloxState philox_args,
+    PhiloxXpuState philox_args,
     int totalSamples,
     int64_t* dest,
     int64_t distributions,
@@ -171,7 +171,7 @@ inline void sampleMultinomialWithReplacement(
   // search due to divergence. It seems possible to compute multiple
   // values and limit divergence though later on.
 
-  auto seeds = philox_unpack(philox_args);
+  auto seeds = at::xpu::philox::unpack(philox_args);
 
   // global index formula for 2D grid of 1D group
   int idx = group_idx_y * group_range_x * thread_range +
@@ -217,7 +217,7 @@ struct MultinomialWithReplacementKernelImplFunctor {
         normDist_ptr);
   }
   MultinomialWithReplacementKernelImplFunctor(
-      PhiloxState rng_engine_inputs_,
+      PhiloxXpuState rng_engine_inputs_,
       const int64_t n_sample_,
       int64_t* result_ptr_,
       int64_t numDist_,
@@ -233,7 +233,7 @@ struct MultinomialWithReplacementKernelImplFunctor {
         normDist_ptr(normDist_ptr_) {}
 
  private:
-  PhiloxState rng_engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   const int64_t n_sample;
   int64_t* result_ptr;
   int64_t numDist;
@@ -509,16 +509,13 @@ void multinomial_kernel(
     int group_range_y = numDist;
     int group_range_x = (n_sample - 1) / group_size + 1;
 
-    std::pair<uint64_t, uint64_t> rng_engine_inputs_;
+    PhiloxXpuState rng_engine_inputs;
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen->mutex_);
      auto offset = ((numDist - 1) / group_range_y + 1) * 4;
-      rng_engine_inputs_ = gen->philox_engine_inputs(offset);
+      rng_engine_inputs = gen->philox_xpu_state(offset);
    }
-    auto rng_engine_inputs = PhiloxState(
-        std::get<0>(rng_engine_inputs_), std::get<1>(rng_engine_inputs_));
-
    // Sample with replacement
    auto result_ptr = result.data_ptr();
    auto prefixSum_ptr = prefixSum.data_ptr();
diff --git a/src/ATen/native/xpu/sycl/RandpermKernel.cpp b/src/ATen/native/xpu/sycl/RandpermKernel.cpp
index 3345a241d..7974c1288 100644
--- a/src/ATen/native/xpu/sycl/RandpermKernel.cpp
+++ b/src/ATen/native/xpu/sycl/RandpermKernel.cpp
@@ -1,5 +1,4 @@
 #include
-#include
 #include
 #include
 #include
@@ -39,7 +38,7 @@ struct HandleDuplicateKeysKernelFunctor {
     // do random permutation inside each island.
     auto data = data_;
     data += tid;
 
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     randStatePhilox4_32_10_t state;
     rand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
@@ -57,7 +56,7 @@ struct HandleDuplicateKeysKernelFunctor {
       scalar_t* data,
       T mask,
       int n,
-      PhiloxState philox_args)
+      PhiloxXpuState philox_args)
       : keys_(keys),
         data_(data),
         mask_(mask),
@@ -69,7 +68,7 @@ struct HandleDuplicateKeysKernelFunctor {
   scalar_t* data_;
   const T mask_;
   const int n_;
-  const PhiloxState philox_args_;
+  const PhiloxXpuState philox_args_;
 };
 
 // See note [Algorithm of randperm]
@@ -85,14 +84,12 @@ void randperm_handle_duplicate_keys(
 
   int64_t counter_offset = n;
 
-  std::pair<uint64_t, uint64_t> rng_engine_inputs_;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs_ = gen->philox_engine_inputs(counter_offset);
+    rng_engine_inputs = gen->philox_xpu_state(counter_offset);
   }
-  PhiloxState rng_engine_inputs(
-      std::get<0>(rng_engine_inputs_), std::get<1>(rng_engine_inputs_));
 
   T mask = static_cast<T>((1UL << bits) - 1);
   HandleDuplicateKeysKernelFunctor kfn(keys, data, mask, n, rng_engine_inputs);
diff --git a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
index 7f6f33805..3fb4a3162 100644
--- a/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
+++ b/src/ATen/native/xpu/sycl/RreluWithNoiseKernels.cpp
@@ -12,7 +12,7 @@ namespace at::native::xpu {
 template
 struct RreluWithNoiseKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
-    auto seeds = philox_unpack(philox_args_);
+    auto seeds = at::xpu::philox::unpack(philox_args_);
     int group_size = item.get_local_range(0);
     int num_groups = item.get_group_range(0);
     int idx = item.get_global_linear_id();
@@ -53,7 +53,7 @@ struct RreluWithNoiseKernelFunctor {
   }
   RreluWithNoiseKernelFunctor(
       int numel,
-      std::pair<uint64_t, uint64_t> rng_engine_inputs,
+      PhiloxXpuState rng_engine_inputs,
       scalar_t* output,
       const scalar_t* input,
       scalar_t* noise,
       double lower,
       double upper,
       transform_t random_func)
       : numel_(numel),
-        philox_args_(PhiloxState(
-            std::get<0>(rng_engine_inputs),
-            std::get<1>(rng_engine_inputs))),
+        philox_args_(rng_engine_inputs),
         output_(output),
         input_(input),
         noise_(noise),
@@ -73,7 +71,7 @@ struct RreluWithNoiseKernelFunctor {
 
  private:
   int numel_;
-  PhiloxState philox_args_;
+  PhiloxXpuState philox_args_;
   scalar_t* output_;
   const scalar_t* input_;
   scalar_t* noise_;
@@ -103,11 +101,11 @@ inline void _rrelu_with_noise_xpu_train(
   auto gen = get_generator_or_default(
       generator, at::xpu::detail::getDefaultXPUGenerator());
 
-  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  PhiloxXpuState rng_engine_inputs;
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+    rng_engine_inputs = gen->philox_xpu_state(counter_offset);
   }
 
   const scalar_t* input_data = input.const_data_ptr();
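Note for reviewers: this patch only consumes PhiloxXpuState and at::xpu::philox::unpack(); their definitions are presumably provided by the header added at the top of DistributionTemplates.h, which is not part of this diff. The sketch below shows the interface the kernels above assume, modeled directly on the removed PhiloxState/philox_unpack pair and on the analogous CUDA PhiloxCudaState; the exact field names, constructors, and header location are assumptions, not the real declarations.

#include <cstdint>
#include <tuple>

// Assumed shape of PhiloxXpuState (sketch only). Eager runs store
// (seed, offset) by value; a captured run stores a device pointer to the
// offset plus an intra-capture increment, exactly as the removed
// PhiloxState did.
struct PhiloxXpuState {
  PhiloxXpuState() = default;
  // Used when graph capture is not underway.
  PhiloxXpuState(uint64_t seed, uint64_t offset) {
    seed_ = seed;
    offset_.val = offset;
  }
  // Used when graph capture is underway.
  PhiloxXpuState(
      uint64_t seed,
      int64_t* offset_extragraph,
      uint32_t offset_intragraph) {
    seed_ = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  uint64_t seed_ = 0;
  Payload offset_;
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};

namespace at::xpu::philox {
// Same contract as the removed philox_unpack(): returns {seed, offset},
// reading the offset through the device pointer only in the captured case.
inline std::tuple<uint64_t, uint64_t> unpack(PhiloxXpuState arg) {
  if (arg.captured_) {
    return std::make_tuple(
        arg.seed_,
        static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
  }
  return std::make_tuple(arg.seed_, arg.offset_.val);
}
} // namespace at::xpu::philox

Under the same assumption, XPUGeneratorImpl::philox_xpu_state(increment) returns this type and plays the role that philox_engine_inputs(increment) played for the old std::pair, which is why every call site above swaps one call for the other without changing the increment argument.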