Skip to content

Commit 0165b67

Browse files
authored
FFT OpenBC Solver: more optimization (AMReX-Codes#4232)
1 parent 294b6fe commit 0165b67

File tree

2 files changed

+72
-27
lines changed

2 files changed

+72
-27
lines changed

Src/FFT/AMReX_FFT_OpenBCSolver.H

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
151151
}
152152
}
153153
}
154+
155+
m_r2c.prepare_openbc();
154156
}
155157

156158
template <typename T>

Src/FFT/AMReX_FFT_R2C.H

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ public:
164164
template <typename F>
165165
void post_forward_doit (F const& post_forward);
166166

167+
void prepare_openbc ();
168+
167169
private:
168170

169171
static std::pair<Plan<T>,Plan<T>> make_c2c_plans (cMF& inout);
@@ -176,6 +178,8 @@ private:
176178
Plan<T> m_fft_bwd_y{};
177179
Plan<T> m_fft_fwd_z{};
178180
Plan<T> m_fft_bwd_z{};
181+
Plan<T> m_fft_fwd_x_half{};
182+
Plan<T> m_fft_bwd_x_half{};
179183

180184
// Comm meta-data. In the forward phase, we start with (x,y,z),
181185
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
@@ -394,6 +398,60 @@ R2C<T,D,S>::~R2C<T,D,S> ()
394398
m_fft_fwd_x.destroy();
395399
m_fft_fwd_y.destroy();
396400
m_fft_fwd_z.destroy();
401+
if (m_fft_bwd_x_half.plan != m_fft_fwd_x_half.plan) {
402+
m_fft_bwd_x_half.destroy();
403+
}
404+
m_fft_fwd_x_half.destroy();
405+
}
406+
407+
template <typename T, Direction D, DomainStrategy S>
408+
void R2C<T,D,S>::prepare_openbc ()
409+
{
410+
#if (AMREX_SPACEDIM == 3)
411+
if (m_slab_decomp) {
412+
auto* fab = detail::get_fab(m_rx);
413+
if (fab) {
414+
Box bottom_half = m_real_domain;
415+
bottom_half.growHi(2,-m_real_domain.length(2)/2);
416+
Box box = fab->box() & bottom_half;
417+
if (box.ok()) {
418+
auto* pr = fab->dataPtr();
419+
auto* pc = (typename Plan<T>::VendorComplex *)
420+
detail::get_fab(m_cx)->dataPtr();
421+
#ifdef AMREX_USE_SYCL
422+
m_fft_fwd_x_half.template init_r2c<Direction::forward>
423+
(box, pr, pc, m_slab_decomp);
424+
m_fft_bwd_x_half = m_fft_fwd_x_half;
425+
#else
426+
if constexpr (D == Direction::both || D == Direction::forward) {
427+
m_fft_fwd_x_half.template init_r2c<Direction::forward>
428+
(box, pr, pc, m_slab_decomp);
429+
}
430+
if constexpr (D == Direction::both || D == Direction::backward) {
431+
m_fft_bwd_x_half.template init_r2c<Direction::backward>
432+
(box, pr, pc, m_slab_decomp);
433+
}
434+
#endif
435+
}
436+
}
437+
} // else todo
438+
439+
if (m_cmd_x2z && ! m_cmd_x2z_half) {
440+
Box bottom_half = m_spectral_domain_z;
441+
// Note that z-direction's index is 0 because we z is the
442+
// unit-stride direction here.
443+
bottom_half.growHi(0,-m_spectral_domain_z.length(0)/2);
444+
m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
445+
(m_cz, bottom_half, m_cx, IntVect(0), m_dtos_x2z);
446+
}
447+
448+
if (m_cmd_z2x && ! m_cmd_z2x_half) {
449+
Box bottom_half = m_spectral_domain_x;
450+
bottom_half.growHi(2,-m_spectral_domain_x.length(2)/2);
451+
m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
452+
(m_cx, bottom_half, m_cz, IntVect(0), m_dtos_z2x);
453+
}
454+
#endif
397455
}
398456

399457
template <typename T, Direction D, DomainStrategy S>
@@ -406,7 +464,8 @@ void R2C<T,D,S>::forward (MF const& inmf)
406464
if (&m_rx != &inmf) {
407465
m_rx.ParallelCopy(inmf, 0, 0, 1);
408466
}
409-
m_fft_fwd_x.template compute_r2c<Direction::forward>();
467+
auto& fft_x = m_openbc_half ? m_fft_fwd_x_half : m_fft_fwd_x;
468+
fft_x.template compute_r2c<Direction::forward>();
410469

411470
if ( m_cmd_x2y) {
412471
ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y);
@@ -419,19 +478,16 @@ void R2C<T,D,S>::forward (MF const& inmf)
419478
#if (AMREX_SPACEDIM == 3)
420479
else if ( m_cmd_x2z) {
421480
if (m_openbc_half) {
422-
Box upper_half = m_spectral_domain_z;
423-
// Note that z-direction's index is 0 because we z is the unit-stride direction here.
424-
upper_half.growLo (0,-m_spectral_domain_z.length(0)/2);
425-
if (! m_cmd_x2z_half) {
426-
Box bottom_half = m_spectral_domain_z;
427-
bottom_half.growHi(0,-m_spectral_domain_z.length(0)/2);
428-
m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
429-
(m_cz, bottom_half, m_cx, IntVect(0), m_dtos_x2z);
430-
}
431481
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
432482
{NonLocalBC::PackComponents{}, m_dtos_x2z};
433483
auto handler = ParallelCopy_nowait(m_cz, m_cx, *m_cmd_x2z_half, packing);
484+
485+
Box upper_half = m_spectral_domain_z;
486+
// Note that z-direction's index is 0 because we z is the
487+
// unit-stride direction here.
488+
upper_half.growLo (0,-m_spectral_domain_z.length(0)/2);
434489
m_cz.setVal(0, upper_half, 0, 1);
490+
435491
ParallelCopy_finish(m_cz, std::move(handler), *m_cmd_x2z_half, packing);
436492
} else {
437493
ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z);
@@ -459,22 +515,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
459515
}
460516
#if (AMREX_SPACEDIM == 3)
461517
else if ( m_cmd_z2x) {
462-
if (m_openbc_half) {
463-
Box upper_half = m_spectral_domain_x;
464-
upper_half.growLo (2,-m_spectral_domain_x.length(2)/2);
465-
if (! m_cmd_z2x_half) {
466-
Box bottom_half = m_spectral_domain_x;
467-
bottom_half.growHi(2,-m_spectral_domain_x.length(2)/2);
468-
m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
469-
(m_cx, bottom_half, m_cz, IntVect(0), m_dtos_z2x);
470-
}
471-
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
472-
{NonLocalBC::PackComponents{}, m_dtos_z2x};
473-
auto handler = ParallelCopy_nowait(m_cx, m_cz, *m_cmd_z2x_half, packing);
474-
ParallelCopy_finish(m_cx, std::move(handler), *m_cmd_z2x_half, packing);
475-
} else {
476-
ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x);
477-
}
518+
auto const& cmd = m_openbc_half ? m_cmd_z2x_half : m_cmd_z2x;
519+
ParallelCopy(m_cx, m_cz, *cmd, 0, 0, 1, m_dtos_z2x);
478520
}
479521
#endif
480522

@@ -483,7 +525,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
483525
ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x);
484526
}
485527

486-
m_fft_bwd_x.template compute_r2c<Direction::backward>();
528+
auto& fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x;
529+
fft_x.template compute_r2c<Direction::backward>();
487530
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout);
488531
}
489532

0 commit comments

Comments
 (0)