@@ -164,6 +164,8 @@ public:
164
164
template <typename F>
165
165
void post_forward_doit (F const & post_forward);
166
166
167
+ void prepare_openbc ();
168
+
167
169
private:
168
170
169
171
static std::pair<Plan<T>,Plan<T>> make_c2c_plans (cMF& inout);
@@ -176,6 +178,8 @@ private:
176
178
Plan<T> m_fft_bwd_y{};
177
179
Plan<T> m_fft_fwd_z{};
178
180
Plan<T> m_fft_bwd_z{};
181
+ Plan<T> m_fft_fwd_x_half{};
182
+ Plan<T> m_fft_bwd_x_half{};
179
183
180
184
// Comm meta-data. In the forward phase, we start with (x,y,z),
181
185
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
@@ -394,6 +398,60 @@ R2C<T,D,S>::~R2C<T,D,S> ()
394
398
m_fft_fwd_x.destroy ();
395
399
m_fft_fwd_y.destroy ();
396
400
m_fft_fwd_z.destroy ();
401
+ if (m_fft_bwd_x_half.plan != m_fft_fwd_x_half.plan ) {
402
+ m_fft_bwd_x_half.destroy ();
403
+ }
404
+ m_fft_fwd_x_half.destroy ();
405
+ }
406
+
407
+ template <typename T, Direction D, DomainStrategy S>
408
+ void R2C<T,D,S>::prepare_openbc ()
409
+ {
410
+ #if (AMREX_SPACEDIM == 3)
411
+ if (m_slab_decomp) {
412
+ auto * fab = detail::get_fab (m_rx);
413
+ if (fab) {
414
+ Box bottom_half = m_real_domain;
415
+ bottom_half.growHi (2 ,-m_real_domain.length (2 )/2 );
416
+ Box box = fab->box () & bottom_half;
417
+ if (box.ok ()) {
418
+ auto * pr = fab->dataPtr ();
419
+ auto * pc = (typename Plan<T>::VendorComplex *)
420
+ detail::get_fab (m_cx)->dataPtr ();
421
+ #ifdef AMREX_USE_SYCL
422
+ m_fft_fwd_x_half.template init_r2c <Direction::forward>
423
+ (box, pr, pc, m_slab_decomp);
424
+ m_fft_bwd_x_half = m_fft_fwd_x_half;
425
+ #else
426
+ if constexpr (D == Direction::both || D == Direction::forward) {
427
+ m_fft_fwd_x_half.template init_r2c <Direction::forward>
428
+ (box, pr, pc, m_slab_decomp);
429
+ }
430
+ if constexpr (D == Direction::both || D == Direction::backward) {
431
+ m_fft_bwd_x_half.template init_r2c <Direction::backward>
432
+ (box, pr, pc, m_slab_decomp);
433
+ }
434
+ #endif
435
+ }
436
+ }
437
+ } // else todo
438
+
439
+ if (m_cmd_x2z && ! m_cmd_x2z_half) {
440
+ Box bottom_half = m_spectral_domain_z;
441
+ // Note that z-direction's index is 0 because we z is the
442
+ // unit-stride direction here.
443
+ bottom_half.growHi (0 ,-m_spectral_domain_z.length (0 )/2 );
444
+ m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
445
+ (m_cz, bottom_half, m_cx, IntVect (0 ), m_dtos_x2z);
446
+ }
447
+
448
+ if (m_cmd_z2x && ! m_cmd_z2x_half) {
449
+ Box bottom_half = m_spectral_domain_x;
450
+ bottom_half.growHi (2 ,-m_spectral_domain_x.length (2 )/2 );
451
+ m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
452
+ (m_cx, bottom_half, m_cz, IntVect (0 ), m_dtos_z2x);
453
+ }
454
+ #endif
397
455
}
398
456
399
457
template <typename T, Direction D, DomainStrategy S>
@@ -406,7 +464,8 @@ void R2C<T,D,S>::forward (MF const& inmf)
406
464
if (&m_rx != &inmf) {
407
465
m_rx.ParallelCopy (inmf, 0 , 0 , 1 );
408
466
}
409
- m_fft_fwd_x.template compute_r2c <Direction::forward>();
467
+ auto & fft_x = m_openbc_half ? m_fft_fwd_x_half : m_fft_fwd_x;
468
+ fft_x.template compute_r2c <Direction::forward>();
410
469
411
470
if ( m_cmd_x2y) {
412
471
ParallelCopy (m_cy, m_cx, *m_cmd_x2y, 0 , 0 , 1 , m_dtos_x2y);
@@ -419,19 +478,16 @@ void R2C<T,D,S>::forward (MF const& inmf)
419
478
#if (AMREX_SPACEDIM == 3)
420
479
else if ( m_cmd_x2z) {
421
480
if (m_openbc_half) {
422
- Box upper_half = m_spectral_domain_z;
423
- // Note that z-direction's index is 0 because we z is the unit-stride direction here.
424
- upper_half.growLo (0 ,-m_spectral_domain_z.length (0 )/2 );
425
- if (! m_cmd_x2z_half) {
426
- Box bottom_half = m_spectral_domain_z;
427
- bottom_half.growHi (0 ,-m_spectral_domain_z.length (0 )/2 );
428
- m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
429
- (m_cz, bottom_half, m_cx, IntVect (0 ), m_dtos_x2z);
430
- }
431
481
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
432
482
{NonLocalBC::PackComponents{}, m_dtos_x2z};
433
483
auto handler = ParallelCopy_nowait (m_cz, m_cx, *m_cmd_x2z_half, packing);
484
+
485
+ Box upper_half = m_spectral_domain_z;
486
+ // Note that z-direction's index is 0 because we z is the
487
+ // unit-stride direction here.
488
+ upper_half.growLo (0 ,-m_spectral_domain_z.length (0 )/2 );
434
489
m_cz.setVal (0 , upper_half, 0 , 1 );
490
+
435
491
ParallelCopy_finish (m_cz, std::move (handler), *m_cmd_x2z_half, packing);
436
492
} else {
437
493
ParallelCopy (m_cz, m_cx, *m_cmd_x2z, 0 , 0 , 1 , m_dtos_x2z);
@@ -459,22 +515,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
459
515
}
460
516
#if (AMREX_SPACEDIM == 3)
461
517
else if ( m_cmd_z2x) {
462
- if (m_openbc_half) {
463
- Box upper_half = m_spectral_domain_x;
464
- upper_half.growLo (2 ,-m_spectral_domain_x.length (2 )/2 );
465
- if (! m_cmd_z2x_half) {
466
- Box bottom_half = m_spectral_domain_x;
467
- bottom_half.growHi (2 ,-m_spectral_domain_x.length (2 )/2 );
468
- m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
469
- (m_cx, bottom_half, m_cz, IntVect (0 ), m_dtos_z2x);
470
- }
471
- NonLocalBC::ApplyDtosAndProjectionOnReciever packing
472
- {NonLocalBC::PackComponents{}, m_dtos_z2x};
473
- auto handler = ParallelCopy_nowait (m_cx, m_cz, *m_cmd_z2x_half, packing);
474
- ParallelCopy_finish (m_cx, std::move (handler), *m_cmd_z2x_half, packing);
475
- } else {
476
- ParallelCopy (m_cx, m_cz, *m_cmd_z2x, 0 , 0 , 1 , m_dtos_z2x);
477
- }
518
+ auto const & cmd = m_openbc_half ? m_cmd_z2x_half : m_cmd_z2x;
519
+ ParallelCopy (m_cx, m_cz, *cmd, 0 , 0 , 1 , m_dtos_z2x);
478
520
}
479
521
#endif
480
522
@@ -483,7 +525,8 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
483
525
ParallelCopy (m_cx, m_cy, *m_cmd_y2x, 0 , 0 , 1 , m_dtos_y2x);
484
526
}
485
527
486
- m_fft_bwd_x.template compute_r2c <Direction::backward>();
528
+ auto & fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x;
529
+ fft_x.template compute_r2c <Direction::backward>();
487
530
outmf.ParallelCopy (m_rx, 0 , 0 , 1 , IntVect (0 ), ngout);
488
531
}
489
532
0 commit comments