Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 33 additions & 59 deletions src/ATen/native/xpu/sycl/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace at::native::xpu {
constexpr int CAT_ARRAY_BATCH_SIZE = 1024;

// Maximum parallel dimension to support
// NOTE(review): CAT_ARRAY_MAX_INPUT_DIMS appears below both as a typed
// `constexpr int` and as a `#define`; in this diff view they look like the
// before/after lines of the change. A macro loses type and scope checking, so
// prefer keeping the `constexpr int` form unless the macro is actually
// required by the free-function kernel machinery -- confirm.
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 5;

#define CAT_ARRAY_MAX_INPUT_DIMS 5
// Similar to any other IndexToOffset calculation for copying along a given
// dimension.
template <typename IndexType, int Dims>
Expand Down Expand Up @@ -66,57 +65,37 @@ struct OutputTensorSizeStride {
IndexType outputStride[MaxDims];
};

// Batched concatenation copy kernel.
//
// This span interleaves two forms of the same kernel body, as rendered by the
// diff view: the functor CatArrayBatchedCopyKernelFunctor (operator() taking a
// sycl::nd_item<2>) and the free function cat_array_batched_copy_kernel (which
// obtains its nd_item from syclext::this_work_item).
//
// Per work-item behaviour, identical in both forms:
//   - `in`  = group index in dimension 0; selects the descriptor inputs[in].
//   - `tid` = global id in dimension 1; the first element index handled.
//   - Work-items with tid >= inputs[in].nElements return immediately.
//   - Otherwise the work-item copies data[tid] to
//     output[offset * dimStride + elementOffset], where elementOffset is
//     produced by CatArrIndexToOffset<IndexType, Dims>::compute, then advances
//     tid by the global range of dimension 1 and repeats while
//     tid < nElements.
template <
typename Tout,
typename underlying_out_t,
typename Tin,
typename underlying_in_t,
typename IndexType,
int Dims>
struct CatArrayBatchedCopyKernelFunctor {
void operator()(sycl::nd_item<2> item) const {
IndexType tid = item.get_global_id(1);
IndexType in = item.get_group(0);

IndexType nElements = inputs[in].nElements;
// Free-function form of the same kernel; arguments are passed directly
// instead of being stored as functor members.
template <typename Tout, typename underlying_out_t, typename Tin,
typename underlying_in_t, typename IndexType, int Dims>
// NOTE(review): the property below declares nd_range_kernel<1>, but the body
// requests get_nd_item<2>() and reads dimension index 1, and the launch site
// in this file builds sycl::range<2> ranges. The dimensionality looks
// inconsistent and probably should be 2 -- confirm against the free-function
// kernel extension spec.
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void cat_array_batched_copy_kernel(
Tout *output, CatArrInputTensor<Tin, IndexType> *inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
// NOTE(review): dimStride is `int` here, while the functor form and the
// caller use IndexType. If IndexType is wider than int this narrows the
// stride before `offset * dimStride` is computed -- confirm intended.
const int concatDim, int dimStride) {
auto item = syclext::this_work_item::get_nd_item<2>();
IndexType tid = item.get_global_id(1);
IndexType in = item.get_group(0);

// Nothing to do for work-items whose first index is already past the end of
// this input.
if (tid >= nElements)
return;
IndexType nElements = inputs[in].nElements;

Tin* data = inputs[in].input;
IndexType offset = inputs[in].offset;
IndexType dimSize = inputs[in].dimSize;
// Base position of this input in the output: its offset scaled by dimStride.
IndexType dataOffset = offset * dimStride;
if (tid >= nElements)
return;

// Step between successive elements handled by one work-item.
IndexType stride = item.get_global_range(1);
Tin *data = inputs[in].input;
IndexType offset = inputs[in].offset;
IndexType dimSize = inputs[in].dimSize;
IndexType dataOffset = offset * dimStride;

// Copy loop: map the linear input index tid to an output offset and store.
while (tid < nElements) {
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];
tid += stride;
}
}
IndexType stride = item.get_global_range(1);

// Stores the kernel arguments as members for use by operator().
CatArrayBatchedCopyKernelFunctor(
Tout* output_,
CatArrInputTensor<Tin, IndexType>* inputs_,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os_,
const int concatDim_,
IndexType dimStride_)
: output(output_),
inputs(inputs_),
os(os_),
concatDim(concatDim_),
dimStride(dimStride_) {}

private:
Tout* output;
CatArrInputTensor<Tin, IndexType>* inputs;
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os;
const int concatDim;
IndexType dimStride;
};
// Copy loop: map the linear input index tid to an output offset and store.
while (tid < nElements) {
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];
tid += stride;
}
}

/**
* Kernel used to concatenate gridDim.y tensors into an output tensor. Uses a
Expand Down Expand Up @@ -146,19 +125,11 @@ void CatArrayBatchedCopy(
const int concatDim,
IndexType dimStride,
int batchCounter) {
CatArrayBatchedCopyKernelFunctor<
Tout,
underlying_out_t,
Tin,
underlying_in_t,
IndexType,
Dims>
kfn(output, inputs, os, concatDim, dimStride);

// Get grid where x dim fills half gpu and y dim is number of tensors.
// This lets concatenating two tensors fill the entire grid, while preventing
// many threads from needlessly loading metadata if their sizes are small.
int64_t numWI = syclMaxWorkGroupSize(kfn);
int64_t numWI = syclMaxWorkGroupSize<cat_array_batched_copy_kernel<
Tout, underlying_out_t, Tin, underlying_in_t, IndexType, Dims>>();

// We limit numWG to prevent over-scheduling.
// numWG = 512 EUs * 8 threads * SIMD lanes 32 / max_compute_units
Expand All @@ -175,8 +146,11 @@ void CatArrayBatchedCopy(
sycl::range<2> global_range(batchCounter, numWG * numWI);
sycl::range<2> local_range(1, numWI);
auto& q = getCurrentSYCLQueue();

sycl_kernel_submit(global_range, local_range, q, kfn);
sycl_kernel_submit<
cat_array_batched_copy_kernel<Tout, underlying_out_t, Tin,
underlying_in_t, IndexType, Dims>,
2>(global_range, local_range, q, 0, output, inputs, os, concatDim,
dimStride);
}

template <
Expand Down