Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 22 additions & 69 deletions src/ATen/native/xpu/sycl/DistributionTemplates.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/native/xpu/sycl/Philox4x32.h>
#include <ATen/native/xpu/sycl/TensorApplyUtils.h>
#include <ATen/ops/empty.h>
#include <ATen/xpu/PhiloxXpuState.h>
#include <comm/DeviceProperties.h>
#include <comm/Runtime.h>

Expand All @@ -23,50 +24,6 @@ using namespace at::xpu;

const uint32_t rand4_engine_calls = 4;

// Plain-data bundle of Philox RNG inputs (seed + offset) that is stored by
// value inside the kernel functors and decoded with philox_unpack().
//
// Two modes, selected by `captured_`:
//   * captured_ == false: `offset_.val` holds the offset itself.
//   * captured_ == true : `offset_.ptr` points at the base offset, and
//     `offset_intragraph_` is added to the value it points to.
struct PhiloxState {
  // Yields seed 0 / offset 0 / not captured. `offset_` is value-initialized
  // below so that unpacking a default-constructed state never reads an
  // indeterminate union member.
  PhiloxState() = default;

  // Called if graph capture is not underway
  PhiloxState(uint64_t seed, uint64_t offset) : seed_(seed) {
    offset_.val = offset;
  }

  // Called if graph capture is underway
  PhiloxState(
      uint64_t seed,
      int64_t* offset_extragraph,
      uint32_t offset_intragraph)
      : seed_(seed),
        offset_intragraph_(offset_intragraph),
        captured_(true) {
    offset_.ptr = offset_extragraph;
  }

  // Either the literal offset (`val`) or a pointer to it (`ptr`); which member
  // is meaningful is determined by `captured_`.
  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  uint64_t seed_ = 0;
  Payload offset_{}; // zero-initialized: val == 0 until a constructor sets it
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};

// Decodes a PhiloxState into a (seed, offset) tuple.
//
// When the state was not built for graph capture, the offset is stored
// directly in the state. Otherwise it is loaded through the stored pointer and
// the intragraph offset is added to it.
inline std::tuple<uint64_t, uint64_t> philox_unpack(PhiloxState arg) {
  if (!arg.captured_) {
    return std::make_tuple(arg.seed_, arg.offset_.val);
  }

  // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire
  // kernel. For most threads' reads it will hit in cache, so it shouldn't
  // hurt performance.
  const int64_t base_offset = *(arg.offset_.ptr);

  // static_cast avoids "warning: invalid narrowing conversion from "long" to
  // "unsigned long".
  const auto resolved_offset =
      static_cast<uint64_t>(base_offset + arg.offset_intragraph_);
  return std::make_tuple(arg.seed_, resolved_offset);
}

template <uint32_t UNROLL = rand4_engine_calls>
inline std::tuple<uint64_t, uint32_t, uint32_t> calc_execution_policy(
int64_t total_elements) {
Expand Down Expand Up @@ -96,7 +53,7 @@ struct DistributionElementwiseKernelFunctor {
int num_groups = item.get_group_range(0);
int idx = item.get_global_linear_id();

auto seeds = philox_unpack(philox_args_);
auto seeds = at::xpu::philox::unpack(philox_args_);
randStatePhilox4_32_10_t state;
rand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);

Expand Down Expand Up @@ -125,23 +82,21 @@ struct DistributionElementwiseKernelFunctor {
}
DistributionElementwiseKernelFunctor(
int64_t numel,
std::pair<uint64_t, uint64_t> rng_engine_inputs,
PhiloxXpuState rng_engine_inputs,
dist_t dist_func,
transform_t transform_func,
char* out_data,
offset_calc_t offset_calc)
: numel_(numel),
philox_args_(PhiloxState(
std::get<0>(rng_engine_inputs),
std::get<1>(rng_engine_inputs))),
philox_args_(rng_engine_inputs),
dist_func_(dist_func),
transform_func_(transform_func),
out_data_(out_data),
offset_calc_(offset_calc) {}

private:
int64_t numel_;
PhiloxState philox_args_;
PhiloxXpuState philox_args_;
dist_t dist_func_;
transform_t transform_func_;
char* out_data_;
Expand Down Expand Up @@ -171,11 +126,11 @@ void distribution_nullary_kernel(
auto num_groups = std::get<1>(execution_policy);
auto group_size = std::get<2>(execution_policy);

std::pair<uint64_t, uint64_t> rng_engine_inputs;
PhiloxXpuState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
rng_engine_inputs = gen->philox_xpu_state(counter_offset);
}

if (!iter.can_use_32bit_indexing()) {
Expand Down Expand Up @@ -234,7 +189,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
int global_size = item.get_global_range(0);
int global_idx = item.get_group(0) * group_size + item.get_local_id(0);

auto seeds = philox_unpack(philox_args_);
auto seeds = at::xpu::philox::unpack(philox_args_);
randStatePhilox4_32_10_t state;
rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);

Expand All @@ -247,7 +202,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
DistributionUnaryElementwiseKernelFunctor(
int numel,
const func_t f,
PhiloxState philox_args,
PhiloxXpuState philox_args,
scalar1_t* output_data,
const scalar2_t* input_data,
inp_offset_calc_t input_offset_calculator,
Expand All @@ -263,7 +218,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
private:
int numel_;
const func_t f_;
PhiloxState philox_args_;
PhiloxXpuState philox_args_;
scalar1_t* output_data_;
const scalar2_t* input_data_;
inp_offset_calc_t inp_calc_;
Expand All @@ -273,7 +228,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
template <typename scalar1_t, typename scalar2_t, typename func_t>
void distribution_unary_kernel(
TensorIterator& iter,
PhiloxState philox_args,
PhiloxXpuState philox_args,
func_t f) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
Expand Down Expand Up @@ -340,7 +295,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
int global_size = item.get_global_range(0);
int global_idx = item.get_group(0) * group_size + item.get_local_id(0);

auto seeds = philox_unpack(philox_args_);
auto seeds = at::xpu::philox::unpack(philox_args_);

randStatePhilox4_32_10_t state;
rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
Expand All @@ -356,7 +311,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
DistributionBinaryElementwiseKernelFunctor(
int numel,
func_t f,
PhiloxState philox_args,
PhiloxXpuState philox_args,
output_t* output_data,
const input_t_1* input_data_1,
const input_t_2* input_data_2,
Expand All @@ -374,7 +329,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
private:
int64_t numel_;
func_t f_;
PhiloxState philox_args_;
PhiloxXpuState philox_args_;
output_t* out_data_;
const input_t_1* inp_data_1_;
const input_t_2* inp_data_2_;
Expand All @@ -385,7 +340,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
template <typename func_t>
void distribution_binary_kernel(
TensorIteratorBase& iter,
PhiloxState philox_args,
PhiloxXpuState philox_args,
const func_t& f) {
static_assert(
std::is_same<
Expand Down Expand Up @@ -762,7 +717,7 @@ struct BernoulliTensorApplyFunctor {
const prob_t& p2,
const prob_t& p3,
const prob_t& p4) const {
auto seeds = philox_unpack(philox_args_);
auto seeds = at::xpu::philox::unpack(philox_args_);
randStatePhilox4_32_10_t state;
rand_init(
std::get<0>(seeds),
Expand Down Expand Up @@ -792,20 +747,18 @@ struct BernoulliTensorApplyFunctor {
}
}
}
BernoulliTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
: philox_args_(
std::get<0>(rng_engine_inputs),
std::get<1>(rng_engine_inputs)) {}
BernoulliTensorApplyFunctor(PhiloxXpuState rng_engine_inputs)
: philox_args_(rng_engine_inputs) {}

private:
PhiloxState philox_args_;
PhiloxXpuState philox_args_;
};

template <typename scalar_t, typename prob_t>
void bernoulli_tensor_kernel(
TensorBase& ret,
TensorBase& p,
std::pair<uint64_t, uint64_t> rng_engine_inputs) {
PhiloxXpuState rng_engine_inputs) {
auto functor =
BernoulliTensorApplyFunctor<scalar_t, prob_t>(rng_engine_inputs);
// The template argument `4` below indicates that we want to operate on four
Expand All @@ -820,11 +773,11 @@ void bernoulli_tensor_kernel(

template <typename RNG>
void bernoulli_kernel(const TensorBase& self, const TensorBase& p_, RNG gen) {
std::pair<uint64_t, uint64_t> rng_engine_inputs;
PhiloxXpuState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(10);
rng_engine_inputs = gen->philox_xpu_state(10);
}
TORCH_CHECK(
at::isFloatingType(p_.scalar_type()),
Expand Down
39 changes: 17 additions & 22 deletions src/ATen/native/xpu/sycl/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct PoissonTensorApplyFunctor {
SYCL_KERNEL_ASSERT(
lambda >= 0 &&
"invalid Poisson rate, expected rate to be non-negative");
auto seeds = philox_unpack(philox_args_);
auto seeds = at::xpu::philox::unpack(philox_args_);
randStatePhilox4_32_10_t state;
rand_init(
std::get<0>(seeds),
Expand All @@ -26,20 +26,18 @@ struct PoissonTensorApplyFunctor {
&state);
ret_val = static_cast<scalar_t>(rand_poisson(&state, lambda));
}
PoissonTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
: philox_args_(
std::get<0>(rng_engine_inputs),
std::get<1>(rng_engine_inputs)) {}
PoissonTensorApplyFunctor(PhiloxXpuState rng_engine_inputs)
: philox_args_(rng_engine_inputs) {}

private:
PhiloxState philox_args_;
PhiloxXpuState philox_args_;
};

template <typename scalar_t>
void poisson_kernel(
const at::TensorBase& ret,
const at::TensorBase& lambda,
std::pair<uint64_t, uint64_t> rng_engine_inputs) {
PhiloxXpuState rng_engine_inputs) {
auto functor = PoissonTensorApplyFunctor<scalar_t>(rng_engine_inputs);
at::native::xpu::tensor_apply2<
scalar_t,
Expand All @@ -55,11 +53,11 @@ void launch_poisson_kernel(
const TensorBase& ret,
const TensorBase& lambda,
at::XPUGeneratorImpl* gen) {
std::pair<uint64_t, uint64_t> rng_engine_inputs;
PhiloxXpuState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(20);
rng_engine_inputs = gen->philox_xpu_state(20);
}
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
Expand Down Expand Up @@ -101,21 +99,19 @@ struct BinomialFunctor {
};

template <typename scalar_t>
void binomial_kernel(TensorIteratorBase& iter, PhiloxState philox_args) {
void binomial_kernel(TensorIteratorBase& iter, PhiloxXpuState philox_args) {
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
BinomialFunctor<scalar_t, accscalar_t> f;
at::native::xpu::distribution_binary_kernel(iter, philox_args, f);
}

void launch_binomial_kernel(TensorIteratorBase& iter, XPUGeneratorImpl* gen) {
std::pair<uint64_t, uint64_t> engine_inputs;
PhiloxXpuState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
engine_inputs = gen->philox_engine_inputs(42);
rng_engine_inputs = gen->philox_xpu_state(42);
}
PhiloxState rng_engine_inputs(
std::get<0>(engine_inputs), std::get<1>(engine_inputs));
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
Expand All @@ -130,7 +126,7 @@ struct GammaTensorApplyFunctor {
sycl::nd_item<1> item,
scalar_t& ret_val,
const scalar_t& alpha) const {
auto seeds = philox_unpack(philox_args_);
auto seeds = at::xpu::philox::unpack(philox_args_);
randStatePhilox4_32_10_t state;
rand_init(
std::get<0>(seeds),
Expand All @@ -155,18 +151,18 @@ struct GammaTensorApplyFunctor {
ret_val = (min_value > sample) ? min_value : sample;
}

GammaTensorApplyFunctor(PhiloxState philox_args)
GammaTensorApplyFunctor(PhiloxXpuState philox_args)
: philox_args_(philox_args) {}

private:
PhiloxState philox_args_;
PhiloxXpuState philox_args_;
};

template <typename scalar_t>
void gamma_kernel(
const at::TensorBase& ret,
const at::TensorBase& alpha,
PhiloxState philox_args) {
PhiloxXpuState philox_args) {
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
GammaTensorApplyFunctor<scalar_t, accscalar_t> functor(philox_args);
at::native::xpu::tensor_apply2<
Expand All @@ -183,18 +179,17 @@ void launch_gamma_kernel(
Tensor& ret,
const Tensor& alpha,
XPUGeneratorImpl* gen) {
std::pair<uint64_t, uint64_t> engine_inputs;
PhiloxXpuState rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
// Request a Philox counter-offset increment of 10 from the generator for
// this kernel. Note that this value is an offset increment, not a seed: the
// seed itself is carried in the state returned by the generator. Modify with
// caution, as changing it affects reproducibility of results.
engine_inputs = gen->philox_engine_inputs(10);
rng_engine_inputs = gen->philox_xpu_state(10);
}
PhiloxState rng_engine_inputs(
std::get<0>(engine_inputs), std::get<1>(engine_inputs));

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
Expand Down
Loading