Skip to content

Commit ea7e8a5

Browse files
authored
amrex::fillAsync (#3076)
This new function can be used to fill the elements in a vector type container (e.g., Gpu::DeviceVector). If the element type is a struct with several arithmetic types (e.g., GpuArray<Real,10>), the usual ParallelFor does not have good performance because of the memory access pattern. The new fillAsync function will use shared memory in that case to improve the performance.
1 parent 9c76f9a commit ea7e8a5

File tree

1 file changed

+99
-1
lines changed

1 file changed

+99
-1
lines changed

Src/Base/AMReX_GpuContainers.H

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,105 @@ namespace Gpu {
402402
Gpu::streamSynchronize();
403403
}
404404

405-
}}
405+
/**
406+
* \brief Fill the elements in the given range using the given
407+
* calllable.
408+
*
409+
* This function is asynchronous for GPU builds.
410+
*
411+
* \tparam IT the iterator type
412+
* \tparam F the callable type
413+
*
414+
* \param first the inclusive first in the range [first, last)
415+
* \param last the exclusive last in the range [first, last)
416+
* \param f the callable with the function signature of void(T&, Long),
417+
* where T is the element type and the Long parameter is the
418+
* index for the element to be filled.
419+
*/
420+
template <typename IT, typename F,
421+
typename T = typename std::iterator_traits<IT>::value_type,
422+
std::enable_if_t<(sizeof(T) <= 36*8) && // so there is enough shared memory
423+
std::is_trivially_copyable_v<T> &&
424+
// std::is_invocable_v<F, T&, Long>, // HIP does not like this.
425+
!std::is_convertible_v<std::decay_t<F>,T>, // So we use this instead.
426+
int> FOO = 0>
427+
void fillAsync (IT first, IT last, F&& f) noexcept
428+
{
429+
auto N = static_cast<Long>(std::distance(first, last));
430+
if (N <= 0) return;
431+
auto p = &(*first);
432+
#ifndef AMREX_USE_GPU
433+
for (Long i = 0; i < N; ++i) {
434+
f(p[i], i);
435+
}
436+
#else
437+
// No need to use shared memory if the type is small.
438+
// May not have enough shared memory if the type is too big.
439+
// Cannot use shared memory, if the type is not trivially copable.
440+
if constexpr ((sizeof(T) <= 8)
441+
|| (sizeof(T) > 36*8)
442+
|| ! std::is_trivially_copyable<T>()) {
443+
amrex::ParallelFor(N, [=] AMREX_GPU_DEVICE (Long i) noexcept
444+
{
445+
f(p[i], i);
446+
});
447+
} else {
448+
static_assert(sizeof(T) % sizeof(unsigned int) == 0);
449+
using U = std::conditional_t<sizeof(T) % sizeof(unsigned long long) == 0,
450+
unsigned long long, unsigned int>;
451+
constexpr Long nU = sizeof(T) / sizeof(U);
452+
auto pu = reinterpret_cast<U*>(p);
453+
int nthreads_per_block = (sizeof(T) <= 64) ? 256 : 128;
454+
int nblocks = static_cast<int>((N+nthreads_per_block-1)/nthreads_per_block);
455+
std::size_t shared_mem_bytes = nthreads_per_block * sizeof(T);
456+
#ifdef AMREX_USE_DPCPP
457+
amrex::launch(nblocks, nthreads_per_block, shared_mem_bytes, Gpu::gpuStream(),
458+
[=] AMREX_GPU_DEVICE (Gpu::Handler const& handler) noexcept
459+
{
460+
Long i = handler.globalIdx();
461+
Long blockDimx = handler.blockDim();
462+
Long threadIdxx = handler.threadIdx();
463+
Long blockIdxx = handler.blockIdx();
464+
auto const shared_U = (U*)handler.sharedMemory();
465+
auto const shared_T = (T*)shared_U;
466+
if (i < N) {
467+
auto ga = new(shared_T+threadIdxx) T;
468+
f(*ga, i);
469+
}
470+
handler.sharedBarrier();
471+
for (Long m = threadIdxx,
472+
mend = nU * amrex::min(blockDimx, N-blockDimx*blockIdxx);
473+
m < mend; m += blockDimx) {
474+
pu[blockDimx*blockIdxx*nU+m] = shared_U[m];
475+
}
476+
});
477+
#else
478+
amrex::launch(nblocks, nthreads_per_block, shared_mem_bytes, Gpu::gpuStream(),
479+
[=] AMREX_GPU_DEVICE () noexcept
480+
{
481+
Long blockDimx = blockDim.x;
482+
Long threadIdxx = threadIdx.x;
483+
Long blockIdxx = blockIdx.x;
484+
Long i = blockDimx*blockIdxx + threadIdxx;
485+
Gpu::SharedMemory<U> gsm;
486+
auto const shared_U = gsm.dataPtr();
487+
auto const shared_T = (T*)shared_U;
488+
if (i < N) {
489+
auto ga = new(shared_T+threadIdxx) T;
490+
f(*ga, i);
491+
}
492+
__syncthreads();
493+
for (Long m = threadIdxx,
494+
mend = nU * amrex::min(blockDimx, N-blockDimx*blockIdxx);
495+
m < mend; m += blockDimx) {
496+
pu[blockDimx*blockIdxx*nU+m] = shared_U[m];
497+
}
498+
});
499+
#endif
500+
}
501+
#endif
502+
}
406503

504+
}}
407505

408506
#endif

0 commit comments

Comments
 (0)